package urlsigner

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"net/url"
	"strconv"
	"strings"
	"time"

	"git.perx.ru/perxis/perxis-go/pkg/data"
)

// URLSigner attaches signatures to URLs and later verifies them.
type URLSigner interface {
	// Sign adds a signature to u's query string. The implementation in this
	// package modifies u in place and returns the same pointer.
	Sign(u *url.URL) *url.URL

	// Check reports whether u carries a valid signature produced by Sign.
	Check(u *url.URL) bool
}

// Fallback settings for NewURLSigner.
const (
	// defaultSignatureExpire is the signature lifetime used when no
	// expiration is supplied to NewURLSigner.
	defaultSignatureExpire = time.Minute * 15

	// defaultQueryKey is the query parameter that carries the signature when
	// no key is supplied to NewURLSigner.
	defaultQueryKey = "sign"
)

// Layout of the decoded signature token.
const (
	// separator delimits the expiry, salt and hash fields of the token.
	separator = "|"

	// saltSize is the length requested from data.GenerateRandomString for
	// each signature's salt.
	saltSize = 16
)

// urlSigner is the URLSigner implementation returned by NewURLSigner.
//
// Fields:
//   - secret: key material mixed into every signature.
//   - expirationTime: how long after Sign a signature remains valid; its
//     string form is also part of the signed data.
//   - queryKey: name of the query parameter holding the encoded signature.
//   - params: names of query parameters whose values are covered by the
//     signature; other query parameters are not protected.
type urlSigner struct {
	secret         string
	expirationTime time.Duration
	queryKey       string
	params         []string
}

// NewURLSigner returns a URLSigner that signs URLs with secret.
//
// expirationTime is how long a signature stays valid after Sign; a zero or
// negative value selects defaultSignatureExpire (a negative lifetime would
// only ever produce signatures that are already expired). queryKey names the
// query parameter that carries the signature; an empty value selects
// defaultQueryKey. params lists the query parameters whose values are
// covered by the signature; other query parameters are not protected.
func NewURLSigner(secret string, expirationTime time.Duration, queryKey string, params ...string) URLSigner {
	if expirationTime <= 0 {
		expirationTime = defaultSignatureExpire
	}
	if queryKey == "" {
		queryKey = defaultQueryKey
	}

	// Copy params: when the caller passes a slice via params..., storing it
	// directly would let later mutations of that slice change which
	// parameters are signed and verified.
	signed := make([]string, len(params))
	copy(signed, params)

	return &urlSigner{
		secret:         secret,
		expirationTime: expirationTime,
		queryKey:       queryKey,
		params:         signed,
	}
}

// Sign stores a signature for u in its query string under s.queryKey and
// returns u; u is modified in place.
//
// The stored value is the URL-safe base64 encoding of
//
//	<expiry as hex Unix seconds>|<random salt>|<HMAC-SHA256 tag>
//
// The tag authenticates u.Path, the values of the query parameters listed in
// s.params, the configured expiration duration, the expiry timestamp and the
// salt. Other URL components (scheme, host, unlisted query parameters) are
// not covered.
func (s *urlSigner) Sign(u *url.URL) *url.URL {
	q := u.Query()
	// Drop any signature left over from an earlier Sign so it can never feed
	// into the tag; Check removes it the same way before verifying.
	q.Del(s.queryKey)

	// NOTE(review): the token is split on separator, so this assumes
	// data.GenerateRandomString never emits "|" — confirm.
	salt := data.GenerateRandomString(saltSize)
	expTime := strconv.FormatInt(time.Now().Add(s.expirationTime).Unix(), 16)
	tag := s.computeSignature(q, u.Path, expTime, salt)

	token := strings.Join([]string{expTime, salt, string(tag)}, separator)
	q.Set(s.queryKey, base64.URLEncoding.EncodeToString([]byte(token)))
	u.RawQuery = q.Encode()
	return u
}

// Check reports whether u carries a signature produced by Sign with the same
// secret and settings, whose expiry has not passed and whose signed
// components (path and listed query parameters) are unchanged.
func (s *urlSigner) Check(u *url.URL) bool {
	q := u.Query()
	token := q.Get(s.queryKey)
	if token == "" {
		return false
	}

	raw, err := base64.URLEncoding.DecodeString(token)
	if err != nil {
		return false
	}

	// The tag is raw binary and may itself contain the separator byte, so
	// split off only the expiry and salt; everything after them is the tag.
	parts := bytes.SplitN(raw, []byte(separator), 3)
	if len(parts) != 3 {
		return false
	}

	// The expiry can be trusted only because it is also covered by the tag
	// verified below; a tampered expiry makes the tag comparison fail.
	expTime := string(parts[0])
	exp, err := strconv.ParseInt(expTime, 16, 64)
	if err != nil || time.Now().Unix() > exp {
		return false
	}

	q.Del(s.queryKey)
	want := s.computeSignature(q, u.Path, expTime, string(parts[1]))
	// Constant-time comparison so response timing does not reveal how many
	// leading tag bytes were guessed correctly.
	return hmac.Equal(parts[2], want)
}

// computeSignature returns the HMAC-SHA256 tag, keyed with s.secret, over the
// signed components of a URL. Every component is length-prefixed
// ("<len>:<bytes>") and each listed parameter also records its value count,
// so bytes cannot be shifted between adjacent components (for example
// a=ab&b=c versus a=a&b=bc) without changing the tag.
func (s *urlSigner) computeSignature(q url.Values, path, expTime, salt string) []byte {
	var msg []byte
	field := func(v string) {
		msg = strconv.AppendInt(msg, int64(len(v)), 10)
		msg = append(msg, ':')
		msg = append(msg, v...)
	}

	for _, p := range s.params {
		vv := q[p]
		field(p)
		field(strconv.Itoa(len(vv)))
		for _, v := range vv {
			field(v)
		}
	}
	field(path)
	field(s.expirationTime.String())
	field(expTime)
	field(salt)

	mac := hmac.New(sha256.New, []byte(s.secret))
	mac.Write(msg) // hash.Hash.Write never returns an error.
	return mac.Sum(nil)
}
