package middleware

import (
	"context"

	"git.perx.ru/perxis/perxis-go/pkg/cache"
	services "git.perx.ru/perxis/perxis-go/pkg/options"
	service "git.perx.ru/perxis/perxis-go/pkg/users"
)

// CachingMiddleware returns a Middleware that wraps a service.Users so that
// users loaded by Get and GetByIdentity are kept in the given cache, and
// Update and Delete evict the affected entries. Every service wrapped by the
// returned Middleware uses the same cache instance.
func CachingMiddleware(cache *cache.Cache) Middleware {
	wrap := func(next service.Users) service.Users {
		mw := &cachingMiddleware{next: next, cache: cache}
		return mw
	}
	return wrap
}

// cachingMiddleware is the service.Users decorator built by
// CachingMiddleware. Get and GetByIdentity store each loaded user in cache
// under its ID and under every one of its identities; Update and Delete
// remove those entries. Create and Find are forwarded to next unchanged.
type cachingMiddleware struct {
	cache *cache.Cache  // keys are user IDs and identities, values are users
	next  service.Users // wrapped service that calls are delegated to
}

// Create forwards the new user to the next service and returns its result
// unchanged. Nothing is cached at creation time.
func (m cachingMiddleware) Create(ctx context.Context, create *service.User) (user *service.User, err error) {
	user, err = m.next.Create(ctx, create)
	return user, err
}

// Get returns the user with the given ID, serving it from the cache when an
// entry exists under userId.
//
// On a cache miss the user is loaded from the next service. On success it is
// cached under its ID and under each of its identities, so later Get and
// GetByIdentity calls for any of those keys can be answered from the cache.
// Errors from the next service are returned unchanged and nothing is cached.
//
// NOTE(review): the cached *service.User pointer itself is returned, so a
// caller that mutates the result also mutates the cached entry. Confirm
// callers treat it as read-only, or return a copy here.
func (m cachingMiddleware) Get(ctx context.Context, userId string) (user *service.User, err error) {
	// A comma-ok assertion turns an unexpected value under this key (the
	// cache is injected and may be shared) into a miss instead of a panic.
	if value, e := m.cache.Get(userId); e == nil {
		if cached, ok := value.(*service.User); ok && cached != nil {
			return cached, nil
		}
	}

	user, err = m.next.Get(ctx, userId)
	if err != nil || user == nil {
		// Nothing to cache. The nil check also prevents a nil dereference
		// if the next service returns (nil, nil).
		return user, err
	}

	m.cache.Set(user.ID, user)
	for _, identity := range user.Identities {
		m.cache.Set(identity, user)
	}
	return user, nil
}

// Find is not cached. It passes filter and options straight to the next
// service and returns the matching users, total count and error as-is.
func (m cachingMiddleware) Find(ctx context.Context, filter *service.Filter, options *services.FindOptions) (users []*service.User, total int, err error) {
	users, total, err = m.next.Find(ctx, filter, options)
	return users, total, err
}

// Update applies update via the next service. On success it evicts the
// user's cached entries so that subsequent reads reload the user.
//
// The keys to evict come from the copy cached under update.ID, not from
// update itself. The update may have changed the user's identities, and the
// stale entries live under the old ones. Errors from the next service are
// returned unchanged and leave the cache untouched.
//
// NOTE(review): identity entries are found only through the entry under the
// user's ID. If that entry is gone while identity entries remain (e.g. after
// eviction), those are not removed here. Confirm the cache's eviction
// behaviour.
func (m cachingMiddleware) Update(ctx context.Context, update *service.User) (err error) {
	if err = m.next.Update(ctx, update); err != nil {
		return err
	}

	value, e := m.cache.Get(update.ID)
	if e != nil {
		// cache.Get reported no usable entry under this ID; nothing to evict.
		return nil
	}
	// Comma-ok assertion: an unexpected value under this key is left alone
	// instead of causing a panic.
	if cached, ok := value.(*service.User); ok && cached != nil {
		m.cache.Remove(cached.ID)
		for _, identity := range cached.Identities {
			m.cache.Remove(identity)
		}
	}
	return nil
}

// Delete removes the user with the given ID via the next service. On success
// it evicts the entries of the user cached under userId: the entry under its
// ID and the entries under each of its identities.
//
// Errors from the next service are returned unchanged and leave the cache
// untouched, since the user may still exist.
//
// NOTE(review): identity entries are found only through the entry under the
// user's ID. If that entry is gone while identity entries remain (e.g. after
// eviction), those are not removed here. Confirm the cache's eviction
// behaviour.
func (m cachingMiddleware) Delete(ctx context.Context, userId string) (err error) {
	if err = m.next.Delete(ctx, userId); err != nil {
		return err
	}

	value, e := m.cache.Get(userId)
	if e != nil {
		// cache.Get reported no usable entry under this ID; nothing to evict.
		return nil
	}
	// Comma-ok assertion: an unexpected value under this key is left alone
	// instead of causing a panic.
	if cached, ok := value.(*service.User); ok && cached != nil {
		m.cache.Remove(cached.ID)
		for _, identity := range cached.Identities {
			m.cache.Remove(identity)
		}
	}
	return nil
}

// GetByIdentity returns the user owning the given identity, serving it from
// the cache when an entry exists under identity.
//
// On a cache miss the user is loaded from the next service. On success it is
// cached under its ID and under each of its identities, so later Get and
// GetByIdentity calls for any of those keys can be answered from the cache.
// Errors from the next service are returned unchanged and nothing is cached.
//
// NOTE(review): the cached *service.User pointer itself is returned, so a
// caller that mutates the result also mutates the cached entry. Confirm
// callers treat it as read-only, or return a copy here.
func (m cachingMiddleware) GetByIdentity(ctx context.Context, identity string) (user *service.User, err error) {
	// A comma-ok assertion turns an unexpected value under this key (the
	// cache is injected and may be shared) into a miss instead of a panic.
	if value, e := m.cache.Get(identity); e == nil {
		if cached, ok := value.(*service.User); ok && cached != nil {
			return cached, nil
		}
	}

	user, err = m.next.GetByIdentity(ctx, identity)
	if err != nil || user == nil {
		// Nothing to cache. The nil check also prevents a nil dereference
		// if the next service returns (nil, nil).
		return user, err
	}

	m.cache.Set(user.ID, user)
	for _, id := range user.Identities {
		m.cache.Set(id, user)
	}
	return user, nil
}
