From 90c5e5e65f8fffdc914763e140e0a3166481d43d Mon Sep 17 00:00:00 2001 From: ko_oler <kooler89@gmail.com> Date: Wed, 12 Feb 2025 15:52:21 +0300 Subject: [PATCH] =?UTF-8?q?=D0=B4=D0=BE=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=20PrincipalInterceptorConnect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/auth/grpc.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/pkg/auth/grpc.go b/pkg/auth/grpc.go index a947a33b..d994cb88 100644 --- a/pkg/auth/grpc.go +++ b/pkg/auth/grpc.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "net/url" + "connectrpc.com/connect" "git.perx.ru/perxis/perxis-go/pkg/errors" kitgrpc "github.com/go-kit/kit/transport/grpc" "golang.org/x/oauth2/clientcredentials" @@ -150,3 +151,62 @@ func TLSCredentials(ctx context.Context, cert, cacert, key []byte) (credentials. } return credentials.NewTLS(&tls.Config{Certificates: []tls.Certificate{clientCert}, RootCAs: certPool}), nil } + +// PrincipalInterceptorConnect интерсептор для клиента и сервера +// используется для получения данных принципала из запроса и добавления в контекст. +func PrincipalInterceptorConnect(factory *PrincipalFactory) connect.UnaryInterceptorFunc { //nolint:gocognit // example + interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { + return func( + ctx context.Context, + req connect.AnyRequest, + ) (connect.AnyResponse, error) { + if req.Spec().IsClient { + p := GetPrincipal(ctx) + switch p := p.(type) { + case *UserPrincipal: + if p.GetIdentity(ctx) != "" { + req.Header().Set(OAuth2IdentityMetadata, p.GetIdentity(ctx)) + } + case *ClientPrincipal: + if ident := p.GetIdentity(ctx); ident != nil { + switch { + case ident.OAuthClientID != "": + req.Header().Set(OAuth2IdentityMetadata, ident.OAuthClientID+"@clients") + case ident.TLSSubject != "": + req.Header().Set(TLSIdentityMetadata, ident.TLSSubject) + case ident.APIKey != "": + req.Header().Set(AuthorizationMetadata, "API-Key "+ident.APIKey) + } + } + case *SystemPrincipal: + req.Header().Set(AccessMetadata, p.GetID(ctx)) + } + return next(ctx, req) + } + if identity := req.Header().Get(TLSIdentityMetadata); identity != "" { + ctx = WithPrincipal(ctx, factory.Principal(identity)) + return next(ctx, req) + } + + if identity := req.Header().Get(OAuth2IdentityMetadata); identity != "" { + ctx = WithPrincipal(ctx, factory.Principal(identity)) + return next(ctx, req) + } + + if identity := req.Header().Get(AuthorizationMetadata); identity != "" { + ctx = WithPrincipal(ctx, factory.Principal(identity)) + return next(ctx, req) + } + + if access := req.Header().Get(AccessMetadata); access != "" { + ctx = WithPrincipal(ctx, factory.Principal(access)) + return next(ctx, req) + } + + ctx = WithPrincipal(ctx, factory.Anonymous()) + return next(ctx, req) + } + } + + return interceptor +} -- GitLab