Skip to content

Commit

Permalink
Add response headers and trailers to gRPC context (#4516)
Browse files Browse the repository at this point in the history
Resolves #4012

Signed-off-by: Tadaya Tsuyukubo <tadaya@ttddyy.net>
  • Loading branch information
ttddyy committed Jan 2, 2024
1 parent d904d76 commit cf8e8c7
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ public class GrpcClientObservationContext extends RequestReplySenderContext<Meta

private String authority;

private Metadata headers;

private Metadata trailers;

public GrpcClientObservationContext(Setter<Metadata> setter) {
super(setter);
}
Expand Down Expand Up @@ -98,4 +102,30 @@ public void setAuthority(String authority) {
this.authority = authority;
}

/**
* Response headers.
* @return response headers
* @since 1.13.0
*/
public Metadata getHeaders() {
return this.headers;
}

public void setHeaders(Metadata headers) {
this.headers = headers;
}

/**
* Trailers.
* @return trailers
* @since 1.13.0
*/
public Metadata getTrailers() {
return this.trailers;
}

public void setTrailers(Metadata trailers) {
this.trailers = trailers;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public class GrpcServerObservationContext extends RequestReplyReceiverContext<Me
@Nullable
private String authority;

private Metadata headers;

private Metadata trailers;

public GrpcServerObservationContext(Getter<Metadata> getter) {
super(getter);
}
Expand Down Expand Up @@ -100,4 +104,30 @@ public void setAuthority(@Nullable String authority) {
this.authority = authority;
}

/**
* Response headers.
* @return response headers
* @since 1.13.0
*/
public Metadata getHeaders() {
return this.headers;
}

public void setHeaders(Metadata headers) {
this.headers = headers;
}

/**
* Trailers.
* @return trailers
* @since 1.13.0
*/
public Metadata getTrailers() {
return this.trailers;
}

public void setTrailers(Metadata trailers) {
this.trailers = trailers;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,26 @@ class ObservationGrpcClientCallListener<RespT> extends SimpleForwardingClientCal
}

@Override
public void onClose(Status status, Metadata metadata) {
public void onHeaders(Metadata headers) {
super.onHeaders(headers);
// Per javadoc, headers are not thread-safe. Make a copy.
Metadata headersToKeep = new Metadata();
headersToKeep.merge(headers);
GrpcClientObservationContext context = (GrpcClientObservationContext) this.observation.getContext();
context.setHeaders(headersToKeep);
}

@Override
public void onClose(Status status, Metadata trailers) {
GrpcClientObservationContext context = (GrpcClientObservationContext) this.observation.getContext();
context.setStatusCode(status.getCode());
context.setTrailers(trailers);
if (status.getCause() != null) {
observation.error(status.getCause());
}
this.observation.stop();
// We do not catch exception from the delegate. (following Brave design)
super.onClose(status, metadata);
super.onClose(status, trailers);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.grpc.Status;
import io.micrometer.core.instrument.binder.grpc.GrpcObservationDocumentation.GrpcServerEvents;
import io.micrometer.observation.Observation;
import io.micrometer.observation.Observation.Context;

/**
* A simple forwarding server call for {@link Observation}.
Expand All @@ -37,6 +38,18 @@ class ObservationGrpcServerCall<ReqT, RespT> extends SimpleForwardingServerCall<
this.observation = observation;
}

@Override
public void sendHeaders(Metadata headers) {
super.sendHeaders(headers);
Context context = this.observation.getContext();
if (context instanceof GrpcServerObservationContext) {
// Per javadoc, headers are not thread-safe. Make a copy.
Metadata headersToKeep = new Metadata();
headersToKeep.merge(headers);
((GrpcServerObservationContext) context).setHeaders(headersToKeep);
}
}

@Override
public void sendMessage(RespT message) {
this.observation.event(GrpcServerEvents.MESSAGE_SENT);
Expand All @@ -48,9 +61,9 @@ public void close(Status status, Metadata trailers) {
if (status.getCause() != null) {
this.observation.error(status.getCause());
}

GrpcServerObservationContext context = (GrpcServerObservationContext) this.observation.getContext();
context.setStatusCode(status.getCode());
context.setTrailers(trailers);
super.close(status, trailers);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,32 @@
*/
package io.micrometer.core.instrument.binder.grpc;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
Expand Down Expand Up @@ -50,13 +70,6 @@
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.awaitility.Awaitility.await;
Expand Down Expand Up @@ -116,11 +129,14 @@ void setUpEchoService() throws Exception {
EchoService echoService = new EchoService();
server = InProcessServerBuilder.forName("sample")
.addService(echoService)
.intercept(new ServerHeaderInterceptor())
.intercept(serverInterceptor)
.build();
server.start();

channel = InProcessChannelBuilder.forName("sample").intercept(clientInterceptor).build();
channel = InProcessChannelBuilder.forName("sample")
.intercept(new ClientHeaderInterceptor(), clientInterceptor)
.build();
}

@Test
Expand All @@ -141,13 +157,14 @@ void unaryRpc() {
GrpcServerEvents.MESSAGE_SENT);
assertThat(clientHandler.getEvents()).containsExactly(GrpcClientEvents.MESSAGE_SENT,
GrpcClientEvents.MESSAGE_RECEIVED);
verifyHeaders();
}

@Test
void unaryRpcAsync() {
SimpleServiceFutureStub stub = SimpleServiceGrpc.newFutureStub(channel);
List<String> messages = new ArrayList<>();
List<String> responses = new ArrayList<>();
List<String> responses = Collections.synchronizedList(new ArrayList<>());
List<ListenableFuture<SimpleResponse>> futures = new ArrayList<>();
int count = 40;
for (int i = 0; i < count; i++) {
Expand All @@ -171,6 +188,7 @@ public void onFailure(Throwable t) {

await().until(() -> futures.stream().allMatch(Future::isDone));
assertThat(responses).hasSize(count).containsExactlyInAnyOrderElementsOf(messages);
verifyHeaders();
}

@Test
Expand Down Expand Up @@ -210,6 +228,7 @@ void clientStreamingRpc() {
verifyServerContext("grpc.testing.SimpleService", "ClientStreamingRpc",
"grpc.testing.SimpleService/ClientStreamingRpc", MethodType.CLIENT_STREAMING);
assertThat(serverHandler.getContext().getStatusCode()).isEqualTo(Code.OK);
verifyHeaders();
}

@Test
Expand Down Expand Up @@ -241,6 +260,7 @@ void serverStreamingRpc() {
assertThat(clientHandler.getContext().getStatusCode()).isEqualTo(Code.OK);
assertThat(clientHandler.getEvents()).containsExactly(GrpcClientEvents.MESSAGE_SENT,
GrpcClientEvents.MESSAGE_RECEIVED, GrpcClientEvents.MESSAGE_RECEIVED);
verifyHeaders();
}

@Test
Expand Down Expand Up @@ -290,6 +310,7 @@ void bidiStreamingRpc() {

assertThat(serverHandler.getContext().getStatusCode()).isEqualTo(Code.OK);
assertThat(clientHandler.getContext().getStatusCode()).isEqualTo(Code.OK);
verifyHeaders();
}

private StreamObserver<SimpleResponse> createResponseObserver(List<String> messages, AtomicBoolean completed) {
Expand All @@ -312,6 +333,17 @@ public void onCompleted() {
};
}

private void verifyHeaders() {
assertThat(clientHandler.getContext().getCarrier().containsKey(ClientHeaderInterceptor.CLIENT_KEY))
.isTrue();
assertThat(clientHandler.getContext().getHeaders().containsKey(ServerHeaderInterceptor.SERVER_KEY))
.isTrue();
assertThat(serverHandler.getContext().getCarrier().containsKey(ClientHeaderInterceptor.CLIENT_KEY))
.isTrue();
assertThat(serverHandler.getContext().getHeaders().containsKey(ServerHeaderInterceptor.SERVER_KEY))
.isTrue();
}

}

@Nested
Expand All @@ -322,11 +354,14 @@ void setUpExceptionService() throws Exception {
ExceptionService exceptionService = new ExceptionService();
server = InProcessServerBuilder.forName("exception")
.addService(exceptionService)
.intercept(new ServerHeaderInterceptor())
.intercept(serverInterceptor)
.build();
server.start();

channel = InProcessChannelBuilder.forName("exception").intercept(clientInterceptor).build();
channel = InProcessChannelBuilder.forName("exception")
.intercept(new ClientHeaderInterceptor(), clientInterceptor)
.build();
}

@Test
Expand Down Expand Up @@ -574,4 +609,51 @@ List<Event> getEvents() {

}

static class ClientHeaderInterceptor implements ClientInterceptor {

private static final Metadata.Key<String> CLIENT_KEY = Metadata.Key.of("client",
Metadata.ASCII_STRING_MARSHALLER);

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
CallOptions callOptions, Channel next) {
ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
return new SimpleForwardingClientCall<>(call) {
@Override
public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
headers.put(CLIENT_KEY, "client-request");
super.start(responseListener, headers);
}

};
}

}

static class ServerHeaderInterceptor implements ServerInterceptor {

private static final Metadata.Key<String> SERVER_KEY = Metadata.Key.of("server",
Metadata.ASCII_STRING_MARSHALLER);

@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
SimpleForwardingServerCall<ReqT, RespT> serverCall = new SimpleForwardingServerCall<>(call) {
@Override
protected ServerCall<ReqT, RespT> delegate() {
return super.delegate();
}

@Override
public void sendHeaders(Metadata headers) {
headers.put(SERVER_KEY, "server-response");
super.sendHeaders(headers);
}

};
return next.startCall(serverCall, headers);
}

}

}

0 comments on commit cf8e8c7

Please sign in to comment.