package events

import (
	"git.perx.ru/perxis/perxis-go/pkg/errors"
	"github.com/nats-io/nats.go"
	"github.com/nats-io/nats.go/encoders/protobuf"
	"google.golang.org/protobuf/proto"
)

// ProtoEncoder is implemented by values that can be converted to and from a
// protobuf message. ProtobufEncoder requires it for both Encode and Decode:
// Encode serializes the message returned by ToProto, while Decode unmarshals
// into the message returned by ToProto and hands it back through FromProto.
type ProtoEncoder interface {
	// FromProto loads the receiver's state from msg.
	FromProto(msg proto.Message) error
	// ToProto returns the protobuf message representing the receiver.
	ToProto() (proto.Message, error)
}

// ProtobufEncoderName is the name under which ProtobufEncoder is registered
// via nats.RegisterEncoder.
const ProtobufEncoderName = "protobuf"

func init() {
	nats.RegisterEncoder(ProtobufEncoderName, &ProtobufEncoder{})
}

// ProtobufEncoder is a nats encoder for values that implement ProtoEncoder.
// It embeds protobuf.ProtobufEncoder for the wire (un)marshalling and
// overrides Encode and Decode to convert between the value and its protobuf
// message.
type ProtobufEncoder struct {
	protobuf.ProtobufEncoder
}

// Sentinel errors returned by ProtobufEncoder when it is given a value that
// does not implement ProtoEncoder. Compare with errors.Is.
var (
	// ErrInvalidProtoMsgEncode is returned by ProtobufEncoder.Encode when v
	// does not implement ProtoEncoder.
	ErrInvalidProtoMsgEncode = errors.New("events: object passed to encode must implement ProtoEncoder")
	// ErrInvalidProtoMsgDecode is returned by ProtobufEncoder.Decode when vPtr
	// does not implement ProtoEncoder. Decoding needs both ToProto (to get the
	// target message) and FromProto (to copy the result back), so the message
	// names ProtoEncoder; there is no separate ProtoDecoder interface.
	ErrInvalidProtoMsgDecode = errors.New("events: object passed to decode must implement ProtoEncoder")
)

// Encode converts v to a protobuf message with ProtoEncoder.ToProto and passes
// that message to the embedded protobuf.ProtobufEncoder.
//
// A nil v yields a nil payload and a nil error. If v does not implement
// ProtoEncoder, ErrInvalidProtoMsgEncode is returned. An error from ToProto is
// returned wrapped.
func (pb *ProtobufEncoder) Encode(subject string, v interface{}) ([]byte, error) {
	if v == nil {
		return nil, nil
	}

	switch src := v.(type) {
	case ProtoEncoder:
		msg, convErr := src.ToProto()
		if convErr != nil {
			return nil, errors.Wrap(convErr, "nats: encode to proto")
		}
		return pb.ProtobufEncoder.Encode(subject, msg)
	default:
		return nil, ErrInvalidProtoMsgEncode
	}
}

// Decode unmarshals data into vPtr, which must implement ProtoEncoder.
//
// vPtr.ToProto supplies the protobuf message to unmarshal into. The embedded
// protobuf.ProtobufEncoder fills that message from data, and vPtr.FromProto
// then loads the decoded message back into vPtr.
//
// It returns ErrInvalidProtoMsgDecode if vPtr does not implement ProtoEncoder.
// It returns a wrapped error if ToProto fails. Otherwise it returns the error,
// if any, from the embedded decoder or from FromProto.
func (pb *ProtobufEncoder) Decode(subject string, data []byte, vPtr interface{}) error {
	enc, ok := vPtr.(ProtoEncoder)
	if !ok {
		return ErrInvalidProtoMsgDecode
	}

	// Do not ignore this error. If ToProto fails, the returned message may not
	// be a usable decode target, and the real cause would otherwise be replaced
	// by whatever the embedded decoder reports.
	msg, err := enc.ToProto()
	if err != nil {
		return errors.Wrap(err, "nats: decode: get target proto message")
	}

	if err = pb.ProtobufEncoder.Decode(subject, data, msg); err != nil {
		return err
	}

	return enc.FromProto(msg)
}