Skip to content
Snippets Groups Projects
Commit ae19b986 authored by ko_oler's avatar ko_oler
Browse files

Добавлена возможность авторизации пользователя с разными Identity с одинаковым email

parent 7432bd15
No related branches found
No related tags found
No related merge requests found
Showing
with 293 additions and 25 deletions
......@@ -61,7 +61,7 @@ func TestExample(t *testing.T) {
Once()
usersService := &usersmocks.Users{}
usersService.On("GetByIdentity", mock.Anything, "74d90aaf").Return(user, nil).Once()
usersService.On("Login", mock.Anything, "74d90aaf", mock.Anything).Return(user, nil).Once()
factory := auth.PrincipalFactory{Users: usersService}
......
......@@ -25,3 +25,17 @@ func WithPrincipal(ctx context.Context, p Principal) context.Context {
func WithSystem(ctx context.Context) context.Context {
return WithPrincipal(ctx, &SystemPrincipal{})
}
type authToken struct{}
func GetAuthToken(ctx context.Context) string {
t, _ := ctx.Value(authToken{}).(string)
return t
}
func WithAuthToken(ctx context.Context, token string) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, authToken{}, token)
}
package auth
import (
"fmt"
"strings"
"git.perx.ru/perxis/perxis-go/pkg/clients"
"git.perx.ru/perxis/perxis-go/pkg/collaborators"
"git.perx.ru/perxis/perxis-go/pkg/environments"
"git.perx.ru/perxis/perxis-go/pkg/errors"
"git.perx.ru/perxis/perxis-go/pkg/members"
"git.perx.ru/perxis/perxis-go/pkg/roles"
"git.perx.ru/perxis/perxis-go/pkg/spaces"
"git.perx.ru/perxis/perxis-go/pkg/users"
"github.com/golang-jwt/jwt/v5"
)
type PrincipalFactory struct {
......@@ -22,9 +25,31 @@ type PrincipalFactory struct {
environments.Environments
}
func (f PrincipalFactory) User(identity string) Principal {
return &UserPrincipal{
identity: identity,
func getValueFromToken(tokenString, name string) (string, error) {
var value string
t := strings.Split(tokenString, "Bearer ")
if len(t) == 2 { //nolint:mnd //not mnd
tokenString = t[1]
}
// Используем ParseUnverified, так как считаем, что токен был уже проверен до получения
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return "", err
}
if claims, ok := token.Claims.(jwt.MapClaims); ok {
value = fmt.Sprint(claims[name])
}
if value == "" {
return "", errors.New("invalid token payload")
}
return value, nil
}
func (f PrincipalFactory) User(identity ...string) Principal {
p := &UserPrincipal{
identity: identity[0],
users: f.Users,
members: f.Members,
roles: f.Roles,
......@@ -32,6 +57,12 @@ func (f PrincipalFactory) User(identity string) Principal {
spaces: f.Spaces,
environments: f.Environments,
}
if len(identity) > 1 {
p.email = identity[1]
}
return p
}
func (f PrincipalFactory) Client(param *clients.GetByParams) Principal {
......@@ -65,6 +96,11 @@ func (f PrincipalFactory) Principal(principalId string) Principal {
return f.Client(&clients.GetByParams{OAuthClientID: strings.TrimSuffix(principalId, "@clients")})
case strings.HasPrefix(principalId, "API-Key"):
return f.Client(&clients.GetByParams{APIKey: strings.TrimPrefix(principalId, "API-Key ")})
case strings.HasPrefix(principalId, "Bearer "):
var email string
email, _ = getValueFromToken(principalId, "email")
principalId, _ = getValueFromToken(principalId, "sub")
return f.User(principalId, email)
default:
return f.User(principalId)
}
......
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_getValueFromToken(t *testing.T) {
tests := []struct {
name string
tokenString string
field string
want string
wantErr bool
}{
{
"With Bearer",
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c3JfaWRlbnRfMiIsImVtYWlsIjoidGVzdEB" +
"0ZXN0LnJ1IiwiaWF0IjoxNTE2MjM5MDIyfQ.MLo310mkPmZdJlIRo3POhevFwd-O_UyxE-1opbQMVVs",
"email",
"test@test.ru",
false,
},
{
"Without Bearer",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c3JfaWRlbnRfMiIsImVtYWlsIjoidGVzdEB0ZXN0Ln" +
"J1IiwiaWF0IjoxNTE2MjM5MDIyfQ.MLo310mkPmZdJlIRo3POhevFwd-O_UyxE-1opbQMVVs",
"email",
"test@test.ru",
false,
},
{
"Sub",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c3JfaWRlbnRfMiIsImVtYWlsIjoidGVzdEB0ZXN0Ln" +
"J1IiwiaWF0IjoxNTE2MjM5MDIyfQ.MLo310mkPmZdJlIRo3POhevFwd-O_UyxE-1opbQMVVs",
"sub",
"usr_ident_2",
false,
},
{
"Invalid token",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.zdWIiOiJ1c3JfaWRlbnRfMiIsImVtYWlsIjoidGVzdEB0ZXN0LnJ1I" +
"iwiaWF0IjoxNTE2MjM5MDIyfQ.MLo310mkPmZdJlIRo3POhevFwd-O_UyxE-1opbQMVVs",
"email",
"test@test.ru",
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getValueFromToken(tt.tokenString, tt.field)
if !tt.wantErr {
require.NoError(t, err)
assert.Equal(t, tt.want, got)
} else {
assert.Error(t, err)
}
})
}
}
......@@ -16,7 +16,6 @@ import (
)
const (
OAuth2IdentityMetadata = "x-perxis-identity"
TLSIdentityMetadata = "x-forwarded-client-cert"
AccessMetadata = "x-perxis-access"
......@@ -28,15 +27,9 @@ func GRPCToContext(factory *PrincipalFactory) kitgrpc.ServerRequestFunc {
if identity := md.Get(TLSIdentityMetadata); len(identity) > 0 {
return WithPrincipal(ctx, factory.Principal(identity[0]))
}
if identity := md.Get(OAuth2IdentityMetadata); len(identity) > 0 {
return WithPrincipal(ctx, factory.Principal(identity[0]))
if token := md.Get(AuthorizationMetadata); len(token) > 0 {
return WithPrincipal(WithAuthToken(ctx, token[0]), factory.Principal(token[0]))
}
if identity := md.Get(AuthorizationMetadata); len(identity) > 0 {
return WithPrincipal(ctx, factory.Principal(identity[0]))
}
if access := md.Get(AccessMetadata); len(access) > 0 {
return WithPrincipal(ctx, factory.System())
}
......@@ -51,19 +44,19 @@ func ContextToGRPC() kitgrpc.ClientRequestFunc {
switch p := p.(type) {
case *UserPrincipal:
if p.GetIdentity(ctx) != "" {
(*md)[OAuth2IdentityMetadata] = []string{p.GetIdentity(ctx)}
(*md)[AuthorizationMetadata] = []string{p.GetIdentity(ctx)}
if GetAuthToken(ctx) != "" {
(*md)[AuthorizationMetadata] = []string{GetAuthToken(ctx)}
}
case *ClientPrincipal:
if ident := p.GetIdentity(ctx); ident != nil {
switch {
case ident.OAuthClientID != "":
(*md)[OAuth2IdentityMetadata] = []string{ident.OAuthClientID + "@clients"}
(*md)[AuthorizationMetadata] = []string{ident.OAuthClientID + "@clients"}
case ident.TLSSubject != "":
(*md)[TLSIdentityMetadata] = []string{ident.TLSSubject}
case ident.APIKey != "":
(*md)[AuthorizationMetadata] = []string{"API-Key " + ident.APIKey}
}
}
case *SystemPrincipal:
......
......@@ -18,6 +18,7 @@ import (
type UserPrincipal struct {
id string
identity string
email string
user *users.User
invalid bool
......@@ -128,8 +129,7 @@ func (u *UserPrincipal) User(ctx context.Context) *users.User {
case u.id != "":
user, err = u.users.Get(WithSystem(ctx), u.id)
case u.identity != "":
ctx = WithSystem(ctx)
user, err = u.users.GetByIdentity(WithSystem(ctx), u.identity)
user, err = u.users.Login(WithSystem(ctx), u.identity, u.email)
}
if err != nil || user == nil {
......
......@@ -128,6 +128,26 @@ func (m *accessLoggingMiddleware) GetByIdentity(ctx context.Context, identity st
return user, err
}
func (m *accessLoggingMiddleware) Login(ctx context.Context, identity string, email string) (user *users.User, err error) {
begin := time.Now()
m.logger.Debug("Login.Request",
zap.Reflect("principal", auth.GetPrincipal(ctx)),
zap.Reflect("identity", identity),
zap.Reflect("email", email),
)
user, err = m.next.Login(ctx, identity, email)
m.logger.Debug("Login.Response",
zap.Duration("time", time.Since(begin)),
zap.Reflect("user", user),
zap.Error(err),
)
return user, err
}
func (m *accessLoggingMiddleware) Update(ctx context.Context, user *users.User) (err error) {
begin := time.Now()
......
......@@ -27,7 +27,6 @@ func (m cachingMiddleware) Create(ctx context.Context, create *service.User) (us
}
func (m cachingMiddleware) Get(ctx context.Context, id string) (user *service.User, err error) {
value, e := m.cache.Get(id)
if e == nil {
return value.(*service.User).Clone(), nil
......@@ -48,7 +47,6 @@ func (m cachingMiddleware) Find(ctx context.Context, filter *service.Filter, opt
}
func (m cachingMiddleware) Update(ctx context.Context, update *service.User) (err error) {
err = m.next.Update(ctx, update)
value, e := m.cache.Get(update.ID)
if err == nil && e == nil {
......@@ -62,7 +60,6 @@ func (m cachingMiddleware) Update(ctx context.Context, update *service.User) (er
}
func (m cachingMiddleware) Delete(ctx context.Context, id string) (err error) {
err = m.next.Delete(ctx, id)
value, e := m.cache.Get(id)
if err == nil && e == nil {
......@@ -76,7 +73,6 @@ func (m cachingMiddleware) Delete(ctx context.Context, id string) (err error) {
}
func (m cachingMiddleware) GetByIdentity(ctx context.Context, identity string) (user *service.User, err error) {
value, e := m.cache.Get(identity)
if e == nil {
return value.(*service.User).Clone(), nil
......@@ -91,3 +87,20 @@ func (m cachingMiddleware) GetByIdentity(ctx context.Context, identity string) (
}
return nil, err
}
//nolint:nonamedreturns //generated
func (m cachingMiddleware) Login(ctx context.Context, identity string, email string) (user *service.User, err error) {
value, e := m.cache.Get(identity)
if e == nil {
return value.(*service.User).Clone(), nil //nolint:errcheck //generated
}
user, err = m.next.Login(ctx, identity, email)
if err == nil {
_ = m.cache.Set(user.ID, user)
for _, i := range user.Identities {
_ = m.cache.Set(i, user)
}
return user.Clone(), nil
}
return nil, err
}
......@@ -80,6 +80,16 @@ func (m *errorLoggingMiddleware) GetByIdentity(ctx context.Context, identity str
return m.next.GetByIdentity(ctx, identity)
}
func (m *errorLoggingMiddleware) Login(ctx context.Context, identity string, email string) (user *users.User, err error) {
logger := m.logger
defer func() {
if err != nil {
logger.Warn("response error", zap.Error(err))
}
}()
return m.next.Login(ctx, identity, email)
}
func (m *errorLoggingMiddleware) Update(ctx context.Context, user *users.User) (err error) {
logger := m.logger
defer func() {
......
......@@ -119,3 +119,17 @@ func (m *loggingMiddleware) GetByIdentity(ctx context.Context, identity string)
}
return user, err
}
//nolint:nonamedreturns //generated
func (m *loggingMiddleware) Login(ctx context.Context, identity string, email string) (user *users.User, err error) {
logger := m.logger.With(
logzap.Caller(ctx),
logzap.Object(pkgId.NewUserId(identity)),
)
user, err = m.next.Login(ctx, identity, email)
if err != nil {
logger.Error("Failed to login", zap.Error(err))
}
return user, err
}
......@@ -91,6 +91,18 @@ func (m *recoveringMiddleware) GetByIdentity(ctx context.Context, identity strin
return m.next.GetByIdentity(ctx, identity)
}
func (m *recoveringMiddleware) Login(ctx context.Context, identity string, email string) (user *users.User, err error) {
logger := m.logger
defer func() {
if r := recover(); r != nil {
logger.Error("panic", zap.Error(fmt.Errorf("%v", r)))
err = fmt.Errorf("%v", r)
}
}()
return m.next.Login(ctx, identity, email)
}
func (m *recoveringMiddleware) Update(ctx context.Context, user *users.User) (err error) {
logger := m.logger
defer func() {
......
......@@ -254,6 +254,47 @@ func (_d telemetryMiddleware) GetByIdentity(ctx context.Context, identity string
return user, err
}
// Login implements users.Users
func (_d telemetryMiddleware) Login(ctx context.Context, identity string, email string) (user *users.User, err error) {
var att = []attribute.KeyValue{
attribute.String("service", "Users"),
attribute.String("method", "Login"),
}
attributes := otelmetric.WithAttributeSet(attribute.NewSet(att...))
start := time.Now()
ctx, _span := otel.Tracer(_d._instance).Start(ctx, "Users.Login")
defer _span.End()
user, err = _d.Users.Login(ctx, identity, email)
_d.requestMetrics.DurationMilliseconds.Record(ctx, time.Since(start).Milliseconds(), attributes)
caller, _ := pkgId.NewObjectId(auth.GetPrincipal(ctx))
if caller != nil {
att = append(att, attribute.String("caller", caller.String()))
}
_d.requestMetrics.Total.Add(ctx, 1, otelmetric.WithAttributeSet(attribute.NewSet(att...)))
if _d._spanDecorator != nil {
_d._spanDecorator(_span, map[string]interface{}{
"ctx": ctx,
"identity": identity,
"email": email}, map[string]interface{}{
"user": user,
"err": err})
} else if err != nil {
_d.requestMetrics.FailedTotal.Add(ctx, 1, attributes)
_span.RecordError(err)
_span.SetAttributes(attribute.String("event", "error"))
_span.SetAttributes(attribute.String("message", err.Error()))
}
return user, err
}
// Update implements users.Users
func (_d telemetryMiddleware) Update(ctx context.Context, user *users.User) (err error) {
var att = []attribute.KeyValue{
......
......@@ -161,6 +161,36 @@ func (_m *Users) GetByIdentity(ctx context.Context, identity string) (*users.Use
return r0, r1
}
// Login provides a mock function with given fields: ctx, identity, email
func (_m *Users) Login(ctx context.Context, identity string, email string) (*users.User, error) {
ret := _m.Called(ctx, identity, email)
if len(ret) == 0 {
panic("no return value specified for Login")
}
var r0 *users.User
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (*users.User, error)); ok {
return rf(ctx, identity, email)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) *users.User); ok {
r0 = rf(ctx, identity, email)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*users.User)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, identity, email)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Update provides a mock function with given fields: ctx, user
func (_m *Users) Update(ctx context.Context, user *users.User) error {
ret := _m.Called(ctx, user)
......
......@@ -16,6 +16,7 @@ type Users interface {
Update(ctx context.Context, user *User) (err error)
Delete(ctx context.Context, id string) (err error)
GetByIdentity(ctx context.Context, identity string) (user *User, err error)
Login(ctx context.Context, identity string, email string) (user *User, err error)
}
type Filter struct {
......
......@@ -65,3 +65,12 @@ func (set EndpointsSet) GetByIdentity(arg0 context.Context, arg1 string) (res0 *
}
return response.(*GetByIdentityResponse).User, res1
}
func (set EndpointsSet) Login(arg0 context.Context, arg1 string, arg2 string) (res0 *users.User, res1 error) {
request := LoginRequest{Identity: arg1, Email: arg2}
response, res1 := set.LoginEndpoint(arg0, &request)
if res1 != nil {
return
}
return response.(*LoginResponse).User, res1
}
......@@ -12,4 +12,5 @@ type EndpointsSet struct {
UpdateEndpoint endpoint.Endpoint
DeleteEndpoint endpoint.Endpoint
GetByIdentityEndpoint endpoint.Endpoint
LoginEndpoint endpoint.Endpoint
}
......@@ -49,4 +49,12 @@ type (
GetByIdentityResponse struct {
User *users.User `json:"user"`
}
LoginRequest struct {
Identity string `json:"identity"`
Email string `json:"email"`
}
LoginResponse struct {
User *users.User `json:"user"`
}
)
......@@ -18,5 +18,6 @@ func NewClient(conn *grpc.ClientConn, opts ...grpckit.ClientOption) transport.En
GetByIdentityEndpoint: grpcerr.ClientMiddleware(c.GetByIdentityEndpoint),
GetEndpoint: grpcerr.ClientMiddleware(c.GetEndpoint),
UpdateEndpoint: grpcerr.ClientMiddleware(c.UpdateEndpoint),
LoginEndpoint: grpcerr.ClientMiddleware(c.LoginEndpoint),
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment