package items

import (
	"context"

	"git.perx.ru/perxis/perxis-go/pkg/data"
	"git.perx.ru/perxis/perxis-go/pkg/errors"
	"git.perx.ru/perxis/perxis-go/pkg/options"
	"google.golang.org/grpc/codes"
)

// BatchProcessor walks the items of a single collection page by page and
// hands every fetched page to a callback (see Do).
//
// The exported fields configure the query; the unexported fields hold the
// pagination state that is maintained while Do is running.
type BatchProcessor struct {
	// Items is the service whose Find / FindPublished methods are queried.
	Items Items
	// SpaceID is passed to every find call.
	SpaceID string
	// EnvID is passed to every find call.
	EnvID string
	// CollectionID is passed to every find call.
	CollectionID string
	// FindOptions supplies the flags, page size and sort order for Find.
	// It is used only when FindPublishedOptions is nil.
	FindOptions *FindOptions
	// FindPublishedOptions, when non-nil, switches the processor to
	// FindPublished and supplies its flags, page size and sort order.
	FindPublishedOptions *FindPublishedOptions
	// Filter is forwarded unchanged to every find call.
	Filter *Filter

	// pageSize is the number of items requested per call.
	pageSize int
	// pageNum is the page index requested by the next call.
	pageNum int
	// sort is the sort order sent with every call.
	sort []string
	// processed counts the items returned by successful calls so far.
	processed int
}

// getBatch requests the current page (b.pageNum, b.pageSize, b.sort) from the
// Items service. FindPublished is used when FindPublishedOptions is set,
// otherwise Find is used with FindOptions.
//
// On success the processed counter and the page number are advanced and the
// fetched items are returned together with a flag telling whether another page
// should be requested. On failure the pagination state is left untouched, so
// the same page can be requested again, and (nil, false, err) is returned.
func (b *BatchProcessor) getBatch(ctx context.Context) ([]*Item, bool, error) {
	var res []*Item
	var err error
	var total int

	if b.FindPublishedOptions != nil {
		res, total, err = b.Items.FindPublished(
			ctx,
			b.SpaceID,
			b.EnvID,
			b.CollectionID,
			b.Filter,
			&FindPublishedOptions{
				Regular:     b.FindPublishedOptions.Regular,
				Hidden:      b.FindPublishedOptions.Hidden,
				Templates:   b.FindPublishedOptions.Templates,
				FindOptions: *options.NewFindOptions(b.pageNum, b.pageSize, b.sort...),
			},
		)
	} else {
		res, total, err = b.Items.Find(
			ctx,
			b.SpaceID,
			b.EnvID,
			b.CollectionID,
			b.Filter,
			&FindOptions{
				Deleted:     b.FindOptions.Deleted,
				Regular:     b.FindOptions.Regular,
				Hidden:      b.FindOptions.Hidden,
				Templates:   b.FindOptions.Templates,
				FindOptions: *options.NewFindOptions(b.pageNum, b.pageSize, b.sort...),
			},
		)
	}
	if err != nil {
		return nil, false, err
	}

	b.processed += len(res)
	b.pageNum++

	// An empty page means the end of the result set has been reached even if
	// the processed counter does not match the reported total (for example
	// because the collection changed while it was being iterated). Without
	// this check the caller would keep requesting empty pages forever.
	more := len(res) > 0 && b.processed != total

	return res, more, nil
}

// next returns the next batch of items. When the request fails with the
// ResourceExhausted status code, the page size is reduced and the request is
// repeated; any other error, or a page size that cannot be reduced further,
// ends the attempt and the error is returned with a nil batch.
func (b *BatchProcessor) next(ctx context.Context) (res []*Item, next bool, err error) {
	for {
		res, next, err = b.getBatch(ctx)
		if err == nil {
			return res, next, nil
		}

		// reducePageSize is consulted only for ResourceExhausted failures.
		exhausted := errors.GetStatusCode(err) == codes.ResourceExhausted
		if !exhausted || !b.reducePageSize() {
			return nil, false, err
		}
	}
}

// reducePageSize shrinks the page size and rescales the page number so that
// the position in the result set (pageNum * pageSize) stays exactly the same.
// It reports false when the page size cannot be reduced any further.
//
// Before the first page has been consumed the size is simply halved. After
// that it is divided by its smallest divisor greater than one: this is plain
// halving for even sizes, while for odd sizes it avoids the rounding of
// pageSize/2, which would move the offset backwards and make items that were
// already handled be fetched again.
func (b *BatchProcessor) reducePageSize() bool {
	if b.pageSize <= 1 {
		return false
	}

	// Offset is zero, so any smaller page size keeps the position.
	if b.pageNum == 0 {
		b.pageSize /= 2
		return true
	}

	// pageSize >= 2 here, so the search stops at pageSize itself at the latest.
	divisor := 2
	for b.pageSize%divisor != 0 {
		divisor++
	}

	b.pageNum *= divisor
	b.pageSize /= divisor

	return true
}

// Do fetches the configured items page by page and calls f once for every
// fetched batch, stopping at the first error returned by the service or by f.
//
// When neither FindOptions nor FindPublishedOptions is set, empty FindOptions
// are used. Page size and sort order are taken from the options, with
// FindPublishedOptions taking precedence when both are set; a non-positive
// page size is replaced by 128. If the filter contains ID or Q conditions and
// the sort order does not already include "_id", "_id" is appended to it.
//
// It returns the number of items fetched, or 0 together with the error that
// interrupted the iteration.
func (b *BatchProcessor) Do(ctx context.Context, f func(batch []*Item) error) (int, error) {
	if b.FindOptions == nil && b.FindPublishedOptions == nil {
		b.FindOptions = new(FindOptions)
	}

	// The page size is re-read from the options below, so a page number or a
	// counter left over from a previous run would no longer describe a valid
	// position. Always start from the first page.
	b.pageNum = 0
	b.processed = 0

	if b.FindOptions != nil {
		b.pageSize = b.FindOptions.PageSize
		b.sort = b.FindOptions.Sort
	}
	if b.FindPublishedOptions != nil {
		b.pageSize = b.FindPublishedOptions.PageSize
		b.sort = b.FindPublishedOptions.Sort
	}

	if b.pageSize <= 0 {
		b.pageSize = 128
	}

	if b.Filter != nil && (len(b.Filter.ID) > 0 || len(b.Filter.Q) > 0) && !data.Contains("_id", b.sort) {
		// b.sort shares its backing array with the caller's Sort slice. Capping
		// the capacity forces append to copy instead of writing into it.
		b.sort = append(b.sort[:len(b.sort):len(b.sort)], "_id")
	}

	next := true
	for next {
		var (
			batch []*Item
			err   error
		)

		batch, next, err = b.next(ctx)
		if err != nil {
			return 0, err
		}

		if err = f(batch); err != nil {
			return 0, err
		}
	}

	return b.processed, nil
}
