From a79d2df27dc4b1b085030be86d975fbc0273054f Mon Sep 17 00:00:00 2001 From: Alena Petraki <alena.petraki@gmail.com> Date: Wed, 11 Oct 2023 18:08:07 +0300 Subject: [PATCH] =?UTF-8?q?=D0=94=D0=BE=D1=80=D0=B0=D0=B1=D0=BE=D1=82?= =?UTF-8?q?=D0=BA=D0=B8=20=D1=81=D0=B5=D1=80=D0=B2=D0=B5=D1=80=D0=B0=20?= =?UTF-8?q?=D0=B4=D0=BB=D1=8F=20=D0=BF=D0=BE=D0=B4=D0=B4=D0=B5=D1=80=D0=B6?= =?UTF-8?q?=D0=BA=D0=B8=20=D0=B5=D0=B3=D0=BE=20=D0=B8=D1=81=D0=BF=D0=BE?= =?UTF-8?q?=D0=BB=D1=8C=D0=B7=D0=BE=D0=B2=D0=B0=D0=BD=D0=B8=D1=8F=20=D0=BC?= =?UTF-8?q?=D0=B5=D0=BD=D0=B5=D0=B4=D0=B6=D0=B5=D1=80=D0=BE=D0=BC=20=D1=80?= =?UTF-8?q?=D0=B0=D1=81=D1=88=D0=B8=D1=80=D0=B5=D0=BD=D0=B8=D0=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/extension/server.go | 174 +++++++++++++++++------------------ pkg/extension/server_test.go | 4 +- 2 files changed, 89 insertions(+), 89 deletions(-) diff --git a/pkg/extension/server.go b/pkg/extension/server.go index b69b10a3..4f8f3bae 100644 --- a/pkg/extension/server.go +++ b/pkg/extension/server.go @@ -11,49 +11,101 @@ import ( "google.golang.org/protobuf/proto" ) +type RouteFn func(ctx context.Context, extensions ...string) ([]Extension, error) + +type WrapErrFn func(extension Extension, err error) error + +func DefaultWrapErrFn() WrapErrFn { + return func(_ Extension, err error) error { return err } +} + +type ServerOption func(c *Server) + +// WrapErrs Оборачивать ошибки, возвращаемые от каждого из +// расширений, в название расширения +func WrapErrs() ServerOption { + return func(c *Server) { + c.wrapErr = func(extension Extension, err error) error { + return errors.Wrap(err, extension.GetDescriptor().Extension) + } + } +} + +// NamedExtensions Перенаправлять запросы к сервисам с названиями, соответствующими +// списку расширений в XXXRequest. Если не найден сервис с названием, возвращается ошибка +func NamedExtensions(extensions ...Extension) ServerOption { + return func(c *Server) { + m := make(map[string]Extension, len(extensions)) + for _, e := range extensions { + m[e.GetDescriptor().Extension] = e + } + c.extensions = func(_ context.Context, extensions ...string) ([]Extension, error) { + var res []Extension + if len(extensions) == 0 { + for _, e := range m { + res = append(res, e) + } + return res, nil + } + + for _, ext := range extensions { + e, ok := m[ext] + if !ok { + return nil, errors.Wrap(ErrUnknownExtension, ext) + } + res = append(res, e) + } + return res, nil + } + c.wrapErr = func(extension Extension, err error) error { + return errors.Wrap(err, extension.GetDescriptor().Extension) + } + } +} + +// SingleExtension Перенаправлять запросы к единому сервису +func SingleExtension(extension Extension) ServerOption { + return func(c *Server) { + c.extensions = func(_ context.Context, _ ...string) ([]Extension, error) { + return []Extension{extension}, nil + } + c.wrapErr = DefaultWrapErrFn() + } +} + type Server struct { - extensions map[string]Extension + extensions RouteFn + wrapErr WrapErrFn operations operation.Service pb.UnimplementedExtensionServiceServer } -func NewServer(operation operation.Service, extensions ...Extension) *Server { +func NewServer(operation operation.Service, options ...ServerOption) *Server { srv := &Server{ - extensions: make(map[string]Extension, len(extensions)), // todo: нужно как-то неявно создавать и регистрировать сервер операций - ? operations: operation, + wrapErr: DefaultWrapErrFn(), } - for _, s := range extensions { - srv.extensions[s.GetDescriptor().Extension] = s - } - return srv -} - -func (s *Server) getExtensions(_ context.Context, extensions []string) ([]Extension, error) { - var res []Extension - for _, ext := range extensions { - e, ok := s.extensions[ext] - if !ok { - return nil, errors.Wrap(ErrUnknownExtension, ext) - } - res = append(res, e) + for _, o := range options { + o(srv) } - return res, nil + + return srv } func (s *Server) Install(ctx context.Context, req *InstallRequest) (*operation.Proto, error) { - exts, err := s.getExtensions(ctx, req.Extensions) + exts, err := s.extensions(ctx, req.Extensions...) if err != nil { return nil, err } - desc := "Install extensions " + strings.Join(req.Extensions, ", ") + desc := "Install extensions: " + strings.Join(req.Extensions, ", ") op, err := s.operations.Create(ctx, desc, func(ctx context.Context) (proto.Message, error) { for _, ext := range exts { if err := ext.Install(ctx, req); err != nil { - return nil, errors.Wrap(err, ext.GetDescriptor().Extension) + return nil, s.wrapErr(ext, err) } } return nil, nil @@ -63,17 +115,17 @@ func (s *Server) Install(ctx context.Context, req *InstallRequest) (*operation.P } func (s *Server) Uninstall(ctx context.Context, req *UninstallRequest) (*operation.Proto, error) { - exts, err := s.getExtensions(ctx, req.Extensions) + exts, err := s.extensions(ctx, req.Extensions...) if err != nil { return nil, err } - desc := "Uninstall extensions " + strings.Join(req.Extensions, ", ") + desc := "Uninstall extensions: " + strings.Join(req.Extensions, ", ") op, err := s.operations.Create(ctx, desc, func(ctx context.Context) (proto.Message, error) { for _, ext := range exts { if err := ext.Uninstall(ctx, req); err != nil { - return nil, errors.Wrap(err, ext.GetDescriptor().Extension) + return nil, s.wrapErr(ext, err) } } return nil, nil @@ -83,17 +135,17 @@ func (s *Server) Uninstall(ctx context.Context, req *UninstallRequest) (*operati } func (s *Server) Check(ctx context.Context, req *CheckRequest) (*operation.Proto, error) { - exts, err := s.getExtensions(ctx, req.Extensions) + exts, err := s.extensions(ctx, req.Extensions...) if err != nil { return nil, err } - desc := "Check extensions " + strings.Join(req.Extensions, ", ") + desc := "Check extensions: " + strings.Join(req.Extensions, ", ") op, err := s.operations.Create(ctx, desc, func(ctx context.Context) (proto.Message, error) { for _, ext := range exts { if err := ext.Check(ctx, req); err != nil { - return nil, errors.Wrap(err, ext.GetDescriptor().Extension) + return nil, s.wrapErr(ext, err) } } return nil, nil @@ -115,12 +167,12 @@ func (s *Server) Action(ctx context.Context, in *pb.ActionRequest) (*pb.ActionRe return nil, errors.New("extension ID required") } - svc, ok := s.extensions[ext] - if !ok { - return nil, ErrUnknownExtension + svc, err := s.extensions(ctx, ext) + if err != nil { + return nil, err } - out, err := svc.Action(ctx, in) + out, err := svc[0].Action(ctx, in) if out == nil { out = &ActionResponse{} @@ -136,8 +188,9 @@ func (s *Server) Action(ctx context.Context, in *pb.ActionRequest) (*pb.ActionRe } func (s *Server) Start() error { + extensions, _ := s.extensions(context.Background()) var errs []error - for _, svc := range s.extensions { + for _, svc := range extensions { if r, ok := svc.(Runnable); ok { if err := r.Start(); err != nil { errs = append(errs, err) @@ -153,8 +206,9 @@ func (s *Server) Start() error { } func (s *Server) Stop() error { + extensions, _ := s.extensions(context.Background()) var errs []error - for _, svc := range s.extensions { + for _, svc := range extensions { if r, ok := svc.(Runnable); ok { if err := r.Stop(); err != nil { errs = append(errs, err) @@ -168,57 +222,3 @@ func (s *Server) Stop() error { return nil } - -// -------------- -// Попытки сделать один сервер для расширений и для менеджера -// -------------- - -type ExtensionsGetter interface { - GetInstalledExtensions(ctx context.Context, extensions ...string) ([]Extension, error) -} - -func NewMultiExtensionsGetter(svc ...Extension) ExtensionsGetter { - g := &multiExtensionsGetter{ - extensions: make(map[string]Extension, len(svc)), - } - for _, s := range svc { - g.extensions[s.GetDescriptor().Extension] = s - } - return g -} - -type multiExtensionsGetter struct { - extensions map[string]Extension -} - -func (g *multiExtensionsGetter) GetInstalledExtensions(_ context.Context, extensions ...string) ([]Extension, error) { - var res []Extension - if len(extensions) == 0 { - for _, e := range g.extensions { - res = append(res, e) - } - return res, nil - } - - for _, ext := range extensions { - e, ok := g.extensions[ext] - if !ok { - return nil, errors.Wrap(ErrUnknownExtension, ext) - } - - res = append(res, e) - } - return res, nil -} - -func NewMonoExtensionsGetter(svc Extension) ExtensionsGetter { - return &monoExtensionsGetter{extension: svc} -} - -type monoExtensionsGetter struct { - extension Extension -} - -func (g *monoExtensionsGetter) GetInstalledExtensions(_ context.Context, _ ...string) ([]Extension, error) { - return []Extension{g.extension}, nil -} diff --git a/pkg/extension/server_test.go b/pkg/extension/server_test.go index 491a6071..cb6d6d3a 100644 --- a/pkg/extension/server_test.go +++ b/pkg/extension/server_test.go @@ -92,7 +92,7 @@ func TestServer_Action(t *testing.T) { EnvId: "env", }, want: nil, - wantErr: ErrUnknownExtension.Error(), + wantErr: errors.Wrap(ErrUnknownExtension, "test-extension-2").Error(), }, { name: "Deprecated call, without extension", @@ -119,7 +119,7 @@ func TestServer_Action(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - srv := NewServer(operation.NewDefaultService(), tt.extension) + srv := NewServer(operation.NewDefaultService(), NamedExtensions(tt.extension)) got, err := srv.Action(context.Background(), tt.in) if tt.wantErr != "" { require.Error(t, err) -- GitLab