diff --git a/pkg/auth/grpc.go b/pkg/auth/grpc.go index a947a33bab75da4c55ca9575e15d915fbc480b67..d994cb88ac01cfb829d832dcf89b3e37b2e56b00 100644 --- a/pkg/auth/grpc.go +++ b/pkg/auth/grpc.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "net/url" + "connectrpc.com/connect" "git.perx.ru/perxis/perxis-go/pkg/errors" kitgrpc "github.com/go-kit/kit/transport/grpc" "golang.org/x/oauth2/clientcredentials" @@ -150,3 +151,62 @@ func TLSCredentials(ctx context.Context, cert, cacert, key []byte) (credentials. } return credentials.NewTLS(&tls.Config{Certificates: []tls.Certificate{clientCert}, RootCAs: certPool}), nil } + +// PrincipalInterceptorConnect интерсептор для клиента и сервера +// используется для получения данных принципала из запроса и добавления в контекст. +func PrincipalInterceptorConnect(factory *PrincipalFactory) connect.UnaryInterceptorFunc { //nolint:gocognit // example + interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { + return func( + ctx context.Context, + req connect.AnyRequest, + ) (connect.AnyResponse, error) { + if req.Spec().IsClient { + p := GetPrincipal(ctx) + switch p := p.(type) { + case *UserPrincipal: + if p.GetIdentity(ctx) != "" { + req.Header().Set(OAuth2IdentityMetadata, p.GetIdentity(ctx)) + } + case *ClientPrincipal: + if ident := p.GetIdentity(ctx); ident != nil { + switch { + case ident.OAuthClientID != "": + req.Header().Set(OAuth2IdentityMetadata, ident.OAuthClientID+"@clients") + case ident.TLSSubject != "": + req.Header().Set(TLSIdentityMetadata, ident.TLSSubject) + case ident.APIKey != "": + req.Header().Set(AuthorizationMetadata, "API-Key "+ident.APIKey) + } + } + case *SystemPrincipal: + req.Header().Set(AccessMetadata, p.GetID(ctx)) + } + return next(ctx, req) + } + if identity := req.Header().Get(TLSIdentityMetadata); identity != "" { + ctx = WithPrincipal(ctx, factory.Principal(identity)) + return next(ctx, req) + } + + if identity := req.Header().Get(OAuth2IdentityMetadata); identity != "" { + ctx = WithPrincipal(ctx, factory.Principal(identity)) + return next(ctx, req) + } + + if identity := req.Header().Get(AuthorizationMetadata); identity != "" { + ctx = WithPrincipal(ctx, factory.Principal(identity)) + return next(ctx, req) + } + + if access := req.Header().Get(AccessMetadata); access != "" { + ctx = WithPrincipal(ctx, factory.Principal(access)) + return next(ctx, req) + } + + ctx = WithPrincipal(ctx, factory.Anonymous()) + return next(ctx, req) + } + } + + return interceptor +}