Skip to content

Commit

Permalink
Keep client interceptors in sync with grpc client interceptors
Browse files Browse the repository at this point in the history
  • Loading branch information
Mihir Gore committed Apr 14, 2021
1 parent e7d26a4 commit 9e8d00f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,32 +48,33 @@ def __init__(self, method, base_callable, interceptor):
self._base_callable = base_callable
self._interceptor = interceptor

def __call__(self, request, timeout=None, metadata=None, credentials=None):
def __call__(self, request, timeout=None, metadata=None, credentials=None, wait_for_ready=None,
compression=None):
def invoker(request, metadata):
return self._base_callable(request, timeout, metadata, credentials)
return self._base_callable(request, timeout, metadata, credentials, wait_for_ready, compression)

client_info = _UnaryClientInfo(self._method, timeout)
return self._interceptor.intercept_unary(
request, metadata, client_info, invoker
)

def with_call(
self, request, timeout=None, metadata=None, credentials=None
self, request, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None
):
def invoker(request, metadata):
return self._base_callable.with_call(
request, timeout, metadata, credentials
request, timeout, metadata, credentials, wait_for_ready, compression
)

client_info = _UnaryClientInfo(self._method, timeout)
return self._interceptor.intercept_unary(
request, metadata, client_info, invoker
)

def future(self, request, timeout=None, metadata=None, credentials=None):
def future(self, request, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None):
def invoker(request, metadata):
return self._base_callable.future(
request, timeout, metadata, credentials
request, timeout, metadata, credentials, wait_for_ready, compression
)

client_info = _UnaryClientInfo(self._method, timeout)
Expand All @@ -88,9 +89,9 @@ def __init__(self, method, base_callable, interceptor):
self._base_callable = base_callable
self._interceptor = interceptor

def __call__(self, request, timeout=None, metadata=None, credentials=None):
def __call__(self, request, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None):
def invoker(request, metadata):
return self._base_callable(request, timeout, metadata, credentials)
return self._base_callable(request, timeout, metadata, credentials, wait_for_ready, compression)

client_info = _StreamClientInfo(self._method, False, True, timeout)
return self._interceptor.intercept_stream(
Expand All @@ -105,11 +106,11 @@ def __init__(self, method, base_callable, interceptor):
self._interceptor = interceptor

def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None
self, request_iterator, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None
):
def invoker(request_iterator, metadata):
return self._base_callable(
request_iterator, timeout, metadata, credentials
request_iterator, timeout, metadata, credentials, wait_for_ready, compression
)

client_info = _StreamClientInfo(self._method, True, False, timeout)
Expand All @@ -118,11 +119,11 @@ def invoker(request_iterator, metadata):
)

def with_call(
self, request_iterator, timeout=None, metadata=None, credentials=None
self, request_iterator, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None
):
def invoker(request_iterator, metadata):
return self._base_callable.with_call(
request_iterator, timeout, metadata, credentials
request_iterator, timeout, metadata, credentials, wait_for_ready, compression
)

client_info = _StreamClientInfo(self._method, True, False, timeout)
Expand All @@ -131,11 +132,11 @@ def invoker(request_iterator, metadata):
)

def future(
self, request_iterator, timeout=None, metadata=None, credentials=None
self, request_iterator, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None
):
def invoker(request_iterator, metadata):
return self._base_callable.future(
request_iterator, timeout, metadata, credentials
request_iterator, timeout, metadata, credentials, wait_for_ready, compression
)

client_info = _StreamClientInfo(self._method, True, False, timeout)
Expand All @@ -151,11 +152,11 @@ def __init__(self, method, base_callable, interceptor):
self._interceptor = interceptor

def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None
self, request_iterator, timeout=None, metadata=None, credentials=None, wait_for_ready=None, compression=None
):
def invoker(request_iterator, metadata):
return self._base_callable(
request_iterator, timeout, metadata, credentials
request_iterator, timeout, metadata, credentials, wait_for_ready, compression
)

client_info = _StreamClientInfo(self._method, True, True, timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,37 @@
from .protobuf.test_server_pb2 import Request


class Interceptor(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
def intercept_unary_unary(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request)

def intercept_unary_stream(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request)

def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
return self._intercept_call(continuation, client_call_details, request_iterator)

def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
return self._intercept_call(continuation, client_call_details, request_iterator)

def _intercept_call(self, continuation, client_call_details, request_or_iterator):
return continuation(client_call_details, request_or_iterator)



class TestClientProto(TestBase):
def setUp(self):
super().setUp()
GrpcInstrumentorClient().instrument()
self.server = create_test_server(25565)
self.server.start()
self.channel = grpc.insecure_channel("localhost:25565")
interceptors = [Interceptor()]
self.channel = grpc.insecure_channel("localhost:25565", interceptors)
self._stub = test_server_pb2_grpc.GRPCTestServerStub(self.channel)

def tearDown(self):
Expand Down

0 comments on commit 9e8d00f

Please sign in to comment.