perxis.interceptors

 1import grpc
 2import collections
 3
 4
 5class _GenericClientInterceptor(
 6    grpc.aio.UnaryUnaryClientInterceptor,
 7    grpc.aio.UnaryStreamClientInterceptor,
 8    grpc.aio.StreamUnaryClientInterceptor,
 9    grpc.aio.StreamStreamClientInterceptor,
10):
11    def __init__(self, interceptor_function):
12        self._fn = interceptor_function
13
14    async def intercept_unary_unary(self, continuation, client_call_details, request):
15        new_details, new_request_iterator, postprocess = self._fn(
16            client_call_details, iter((request,)), False, False
17        )
18        response = await continuation(new_details, next(new_request_iterator))
19        return postprocess(response) if postprocess else response
20
21    async def intercept_unary_stream(self, continuation, client_call_details, request):
22        new_details, new_request_iterator, postprocess = self._fn(
23            client_call_details, iter((request,)), False, True
24        )
25        response_it = continuation(new_details, next(new_request_iterator))
26        return postprocess(response_it) if postprocess else response_it
27
28    async def intercept_stream_unary(
29        self, continuation, client_call_details, request_iterator
30    ):
31        new_details, new_request_iterator, postprocess = self._fn(
32            client_call_details, request_iterator, True, False
33        )
34        response = await continuation(new_details, new_request_iterator)
35        return postprocess(response) if postprocess else response
36
37    async def intercept_stream_stream(
38        self, continuation, client_call_details, request_iterator
39    ):
40        new_details, new_request_iterator, postprocess = self._fn(
41            client_call_details, request_iterator, True, True
42        )
43        response_it = continuation(new_details, new_request_iterator)
44        return postprocess(response_it) if postprocess else response_it
45
46
def create(intercept_call):
    """Wrap ``intercept_call`` in a ``_GenericClientInterceptor``.

    Args:
        intercept_call: Callable taking ``(client_call_details, request_iterator,
            request_streaming, response_streaming)`` and returning
            ``(new_details, new_request_iterator, postprocess)``.

    Returns:
        The newly constructed interceptor instance.
    """
    interceptor = _GenericClientInterceptor(intercept_call)
    return interceptor
49
50
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails",
        "method timeout metadata credentials wait_for_ready",
    ),
    grpc.ClientCallDetails,
):
    """Tuple-backed concrete ``grpc.ClientCallDetails``.

    Fields, in positional order: ``method``, ``timeout``, ``metadata``,
    ``credentials``, ``wait_for_ready``. Used to pass modified call details
    on to the next step of an interceptor chain.
    """
59
60
def header_adder_interceptor(header, value):
    """Return an interceptor that appends one fixed metadata entry to every call.

    Args:
        header: Metadata key to add.
        value: Metadata value sent under ``header``.

    Returns:
        The interceptor built by ``create`` around the per-call hook. The
        ``(header, value)`` pair is appended after any metadata already present
        on the call, and existing entries are kept.
    """

    def intercept_call(
        client_call_details, request_iterator, request_streaming, response_streaming
    ):
        # Build a fresh list so the incoming metadata object is never mutated.
        existing = client_call_details.metadata
        metadata = list(existing) if existing is not None else []
        metadata.append((header, value))

        updated_details = _ClientCallDetails(
            method=client_call_details.method,
            timeout=client_call_details.timeout,
            metadata=metadata,
            credentials=client_call_details.credentials,
            wait_for_ready=client_call_details.wait_for_ready,
        )
        # Requests pass through untouched; no response post-processing.
        return updated_details, request_iterator, None

    return create(intercept_call)
def create(intercept_call):
    """Wrap ``intercept_call`` in a ``_GenericClientInterceptor``.

    Args:
        intercept_call: Callable taking ``(client_call_details, request_iterator,
            request_streaming, response_streaming)`` and returning
            ``(new_details, new_request_iterator, postprocess)``.

    Returns:
        The newly constructed interceptor instance.
    """
    interceptor = _GenericClientInterceptor(intercept_call)
    return interceptor
def header_adder_interceptor(header, value):
    """Return an interceptor that appends one fixed metadata entry to every call.

    Args:
        header: Metadata key to add.
        value: Metadata value sent under ``header``.

    Returns:
        The interceptor built by ``create`` around the per-call hook. The
        ``(header, value)`` pair is appended after any metadata already present
        on the call, and existing entries are kept.
    """

    def intercept_call(
        client_call_details, request_iterator, request_streaming, response_streaming
    ):
        # Build a fresh list so the incoming metadata object is never mutated.
        existing = client_call_details.metadata
        metadata = list(existing) if existing is not None else []
        metadata.append((header, value))

        updated_details = _ClientCallDetails(
            method=client_call_details.method,
            timeout=client_call_details.timeout,
            metadata=metadata,
            credentials=client_call_details.credentials,
            wait_for_ready=client_call_details.wait_for_ready,
        )
        # Requests pass through untouched; no response post-processing.
        return updated_details, request_iterator, None

    return create(intercept_call)