From f3aaaa598f485903c630e015b21a942ca3af9557 Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 30 May 2024 22:16:13 -0700 Subject: [PATCH] http (fix, breaking): RPCContext.current.getThreadLocal interface change to avoid unsafe type cast (#3548) - Breaking change: `RPCContext.current.getThreadLocal[A](key: String): A` -> `RPCContext.current.getThreadLocal(key: String): Any` to avoid type cast error - Also, fixed a bug of getting the previous thread-local values --- .../http/finagle/FinagleRPCContext.scala | 6 ++++- .../http/finagle/ThreadLocalStorageTest.scala | 2 +- .../airframe/http/grpc/GrpcContext.scala | 16 ++++------- .../airframe/http/grpc/example/DemoApi.scala | 2 +- .../airframe/http/netty/NettyBackend.scala | 12 +++------ .../airframe/http/netty/NettyRPCContext.scala | 13 +++++---- .../http/netty/NettyBackendTest.scala | 12 ++++----- .../http/netty/NettyLoggingTest.scala | 21 ++++++++++++--- .../http/internal/LocalRPCContext.scala | 2 +- .../airframe/http/internal/TLSSupport.scala | 27 +++++++++++++++++++ .../wvlet/airframe/http/RPCContext.scala | 20 +++++++++++--- docs/airframe-http.md | 6 ++--- docs/airframe-rpc.md | 4 +-- 13 files changed, 94 insertions(+), 49 deletions(-) create mode 100644 airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/TLSSupport.scala diff --git a/airframe-http-finagle/src/main/scala/wvlet/airframe/http/finagle/FinagleRPCContext.scala b/airframe-http-finagle/src/main/scala/wvlet/airframe/http/finagle/FinagleRPCContext.scala index 1b68a54596..ceedcb61e2 100644 --- a/airframe-http-finagle/src/main/scala/wvlet/airframe/http/finagle/FinagleRPCContext.scala +++ b/airframe-http-finagle/src/main/scala/wvlet/airframe/http/finagle/FinagleRPCContext.scala @@ -24,10 +24,14 @@ case class FinagleRPCContext(request: Request) extends RPCContext { FinagleBackend.setThreadLocal(key, value) } - override def getThreadLocal[A](key: String): Option[A] = { + override def getThreadLocal(key: String): Option[Any] = { FinagleBackend.getThreadLocal(key) } + override def getThreadLocalUnsafe[A](key: String): Option[A] = { + getThreadLocal(key).map(_.asInstanceOf[A]) + } + override def httpRequest: HttpMessage.Request = { request.toHttpRequest } diff --git a/airframe-http-finagle/src/test/scala/wvlet/airframe/http/finagle/ThreadLocalStorageTest.scala b/airframe-http-finagle/src/test/scala/wvlet/airframe/http/finagle/ThreadLocalStorageTest.scala index eb7d023ffd..ef4a0d702d 100644 --- a/airframe-http-finagle/src/test/scala/wvlet/airframe/http/finagle/ThreadLocalStorageTest.scala +++ b/airframe-http-finagle/src/test/scala/wvlet/airframe/http/finagle/ThreadLocalStorageTest.scala @@ -40,7 +40,7 @@ class ThreadLocalStorageTest extends AirSpec { @Endpoint(path = "/rpc-context") def rpcContext: String = { - RPCContext.current.getThreadLocal[String]("client_id").getOrElse("unknown") + RPCContext.current.getThreadLocal("client_id").map(_.toString).getOrElse("unknown") } @Endpoint(path = "/rpc-header") diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcContext.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcContext.scala index c145b1e36b..de2791a93b 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcContext.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcContext.scala @@ -14,6 +14,7 @@ package wvlet.airframe.http.grpc import io.grpc.* +import wvlet.airframe.http.internal.TLSSupport import wvlet.airframe.http.{Http, HttpMessage, RPCContext, RPCEncoding} import wvlet.log.LogSupport @@ -57,16 +58,9 @@ case class GrpcContext( metadata: Metadata, descriptor: MethodDescriptor[_, _] ) extends RPCContext + with TLSSupport with LogSupport { - // Grpc doesn't provide a mutable thread-local stage, so create our own TLS here. - private lazy val tls = - ThreadLocal.withInitial[collection.mutable.Map[String, Any]](() => mutable.Map.empty[String, Any]) - - private def storage: collection.mutable.Map[String, Any] = { - tls.get() - } - // Return the accept header def accept: String = metadata.accept def encoding: RPCEncoding = accept match { @@ -79,11 +73,11 @@ case class GrpcContext( } override def setThreadLocal[A](key: String, value: A): Unit = { - storage.put(key, value) + setTLS(key, value) } - override def getThreadLocal[A](key: String): Option[A] = { - storage.get(key).asInstanceOf[Option[A]] + override def getThreadLocal(key: String): Option[Any] = { + getTLS(key) } override def httpRequest: HttpMessage.Request = { diff --git a/airframe-http-grpc/src/test/scala/wvlet/airframe/http/grpc/example/DemoApi.scala b/airframe-http-grpc/src/test/scala/wvlet/airframe/http/grpc/example/DemoApi.scala index eefaca6b2c..07fe75c7ae 100644 --- a/airframe-http-grpc/src/test/scala/wvlet/airframe/http/grpc/example/DemoApi.scala +++ b/airframe-http-grpc/src/test/scala/wvlet/airframe/http/grpc/example/DemoApi.scala @@ -47,7 +47,7 @@ trait DemoApi extends LogSupport { def getRPCContext: Option[String] = { val ctx = RPCContext.current - ctx.getThreadLocal[String]("client_id") + ctx.getThreadLocal("client_id").map(_.toString) } def getRequest: Request = { diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBackend.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBackend.scala index a741a89416..351eeffbe6 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBackend.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBackend.scala @@ -15,6 +15,7 @@ package wvlet.airframe.http.netty import wvlet.airframe.http.HttpMessage.{Request, Response} import wvlet.airframe.http.* +import wvlet.airframe.http.internal.TLSSupport import wvlet.airframe.rx.Rx import wvlet.log.LogSupport @@ -22,7 +23,7 @@ import scala.collection.mutable import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.util.{Failure, Success} -object NettyBackend extends HttpBackend[Request, Response, Rx] with LogSupport { self => +object NettyBackend extends HttpBackend[Request, Response, Rx] with TLSSupport with LogSupport { self => private val rxBackend = new RxNettyBackend override protected implicit val httpRequestAdapter: HttpRequestAdapter[Request] = @@ -89,21 +90,16 @@ object NettyBackend extends HttpBackend[Request, Response, Rx] with LogSupport { f.toRx.map(body) } - private lazy val tls = - ThreadLocal.withInitial[collection.mutable.Map[String, Any]](() => mutable.Map.empty[String, Any]) - - private def storage: collection.mutable.Map[String, Any] = tls.get() - override def withThreadLocalStore(request: => Rx[Response]): Rx[Response] = { // request } override def setThreadLocal[A](key: String, value: A): Unit = { - storage.put(key, value) + setTLS(key, value) } override def getThreadLocal[A](key: String): Option[A] = { - storage.get(key).asInstanceOf[Option[A]] + getTLS(key).map(_.asInstanceOf[A]) } } diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRPCContext.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRPCContext.scala index 8a9835f55e..bbe6f7bb07 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRPCContext.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRPCContext.scala @@ -15,12 +15,11 @@ package wvlet.airframe.http.netty import wvlet.airframe.http.HttpMessage.Request import wvlet.airframe.http.RPCContext +import wvlet.airframe.http.internal.TLSSupport -class NettyRPCContext(val httpRequest: Request) extends RPCContext { - override def setThreadLocal[A](key: String, value: A): Unit = { - NettyBackend.setThreadLocal(key, value) - } - override def getThreadLocal[A](key: String): Option[A] = { - NettyBackend.getThreadLocal(key) - } +import scala.collection.mutable + +class NettyRPCContext(val httpRequest: Request) extends RPCContext with TLSSupport { + override def setThreadLocal[A](key: String, value: A): Unit = setTLS(key, value) + override def getThreadLocal(key: String): Option[Any] = getTLS(key) } diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyBackendTest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyBackendTest.scala index 44fec41b77..ce71118ada 100644 --- a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyBackendTest.scala +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyBackendTest.scala @@ -22,25 +22,25 @@ class NettyBackendTest extends AirSpec { val key = ULID.newULIDString test("must be None by default") { - NettyBackend.getThreadLocal[Int](key) shouldBe None + NettyBackend.getThreadLocal(key) shouldBe None } test("store different content for each thread") { - NettyBackend.setThreadLocal[Int](key, 123) + NettyBackend.setThreadLocal(key, 123) var valueInThread: Option[Int] = None val t = new Thread { override def run(): Unit = { - NettyBackend.getThreadLocal[Int](key) shouldBe None - NettyBackend.setThreadLocal[Int](key, 456) - valueInThread = NettyBackend.getThreadLocal[Int](key) + NettyBackend.getThreadLocal(key) shouldBe None + NettyBackend.setThreadLocal(key, 456) + valueInThread = NettyBackend.getThreadLocal(key) } } t.start() t.join() - NettyBackend.getThreadLocal[Int](key) shouldBe Some(123) + NettyBackend.getThreadLocal(key) shouldBe Some(123) valueInThread shouldBe Some(456) } } diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyLoggingTest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyLoggingTest.scala index 4b276f3c3c..576fbfef20 100644 --- a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyLoggingTest.scala +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/NettyLoggingTest.scala @@ -26,9 +26,14 @@ object NettyLoggingTest extends AirSpec { @RPC class MyRPC extends LogSupport { + private var requestCount = 0 + def hello(): Unit = { - RPCContext.current.setThreadLocal("user", "xxxx_yyyy") - debug("hello rpc") + if (requestCount == 0) { + RPCContext.current.setThreadLocal("user", "xxxx_yyyy") + } + requestCount += 1 + trace("hello rpc") } } @@ -46,7 +51,7 @@ object NettyLoggingTest extends AirSpec { .withName("log-test-server") .withExtraLogEntries { () => val m = ListMap.newBuilder[String, Any] - RPCContext.current.getThreadLocal[String]("user").foreach { v => + RPCContext.current.getThreadLocal("user").foreach { v => m += "user" -> v } m += ("custom_log_entry" -> "test") @@ -67,12 +72,20 @@ object NettyLoggingTest extends AirSpec { test("add server custom log") { (syncClient: SyncClient) => syncClient.send(Http.POST("/wvlet.airframe.http.netty.NettyLoggingTest.MyRPC/hello")) - val logEntry = serverLogger.getLogs.head + val logs = serverLogger.getLogs + val logEntry = logs(0) debug(logEntry) logEntry shouldContain ("server_name" -> "log-test-server") logEntry shouldContain ("custom_log_entry" -> "test") logEntry shouldContain ("user" -> "xxxx_yyyy") + test("do not set TLS in the second request") { + syncClient.send(Http.POST("/wvlet.airframe.http.netty.NettyLoggingTest.MyRPC/hello")) + val l = serverLogger.getLogs(1) + debug(l) + l shouldNotContain ("user" -> "xxxx_yyyy") + } + test("add client custom log") { val clientLogEntry = clientLogger.getLogs.head debug(clientLogEntry) diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/LocalRPCContext.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/LocalRPCContext.scala index ef1769e0f1..d5dd1fbf8c 100644 --- a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/LocalRPCContext.scala +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/LocalRPCContext.scala @@ -13,7 +13,7 @@ */ package wvlet.airframe.http.internal -import wvlet.airframe.http.{RPCContext, EmptyRPCContext} +import wvlet.airframe.http.{EmptyRPCContext, RPCContext} object LocalRPCContext { private val localContext = new ThreadLocal[RPCContext]() diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/TLSSupport.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/TLSSupport.scala new file mode 100644 index 0000000000..31f420be79 --- /dev/null +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/internal/TLSSupport.scala @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.internal + +import scala.collection.mutable + +/** + * Thread-local storage support + */ +private[http] trait TLSSupport { + private lazy val tls = ThreadLocal.withInitial[mutable.Map[String, Any]](() => mutable.Map.empty[String, Any]) + private def tlsStorage(): mutable.Map[String, Any] = tls.get() + + def setTLS(key: String, value: Any): Unit = tlsStorage().put(key, value) + def getTLS(key: String): Option[Any] = tlsStorage().get(key) +} diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/RPCContext.scala b/airframe-http/src/main/scala/wvlet/airframe/http/RPCContext.scala index fb0a2de72f..cc1247c72e 100644 --- a/airframe-http/src/main/scala/wvlet/airframe/http/RPCContext.scala +++ b/airframe-http/src/main/scala/wvlet/airframe/http/RPCContext.scala @@ -37,7 +37,10 @@ trait RPCContext { def httpRequest: HttpMessage.Request def rpcCallContext: Option[RPCCallContext] = { - getThreadLocal[RPCCallContext](HttpBackend.TLS_KEY_RPC) + getThreadLocal(HttpBackend.TLS_KEY_RPC) match { + case Some(c: RPCCallContext) => Some(c) + case _ => None + } } /** @@ -52,10 +55,19 @@ trait RPCContext { * Get a thread-local variable that is available only within the request scope. The type must be specified * explicitly. * @param key - * @tparam A * @return */ - def getThreadLocal[A](key: String): Option[A] + @deprecated("Use getThreadLocal(key: String): Any instead", "24.5.0") + def getThreadLocalUnsafe[A](key: String): Option[A] = { + getThreadLocal(key).map(_.asInstanceOf[A]) + } + + /** + * Get a thread-local variable that is available only within the request scope. + * @param key + * @return + */ + def getThreadLocal(key: String): Option[Any] } /** @@ -65,7 +77,7 @@ object EmptyRPCContext extends RPCContext { override def setThreadLocal[A](key: String, value: A): Unit = { // no-op } - override def getThreadLocal[A](key: String): Option[A] = { + override def getThreadLocal(key: String): Option[Any] = { // no-op None } diff --git a/docs/airframe-http.md b/docs/airframe-http.md index d1f5b00e34..45a0a66e86 100644 --- a/docs/airframe-http.md +++ b/docs/airframe-http.md @@ -150,14 +150,14 @@ val server = Netty.server // Add a custom log entry m += "application_version" -> "1.0" // Add a thread-local parameter to the log - RPCContext.current.getThreadLocal[String]("user_id").map { uid => + RPCContext.current.getThreadLocal("user_id").map { uid => m += "user_id" -> uid } m.result } // [optional] Disable server-side logging (log/http_server.json) .noLogging - // Add a custom MessageCodec mapping + // [optional] Add a custom MessageCodec mapping .withCustomCodec{ case s: Surface.of[MyClass] => ... } server.start { server => @@ -372,7 +372,7 @@ object AuthLogFilter extends RxHttpFilter with LogSupport { def apply(request: Request, next: RxHttpEndpoint): Rx[Response] = { next(request).map { response => // Read the thread-local parameter set in the context(request) - RPCContext.current.getThreadLocal[String]("user_id").map { uid => + RPCContext.current.getThreadLocal("user_id").map { uid => info(s"user_id: ${uid}") } response diff --git a/docs/airframe-rpc.md b/docs/airframe-rpc.md index e91b243f7c..855004a55a 100644 --- a/docs/airframe-rpc.md +++ b/docs/airframe-rpc.md @@ -447,7 +447,7 @@ String "100" will be translated into an Int value `100` automatically. ### RPCContext -Since Airframe 22.8.0, airframe-rpc introduced `RPCContext` for reading and writing the thread-local storage, and referencing the original HTTP request: +Since Airframe 22.8.0, airframe-rpc introduced `RPCContext.current` for reading and writing the thread-local storage, and referencing the original HTTP request: ```scala import wvlet.airframe.http._ @@ -456,7 +456,7 @@ import wvlet.airframe.http._ trait MyAPI { def hello: String = { // Read the thread-local storage - val userName = RPCContext.current.getThreadLocal[String]("context_user") + val userName = RPCContext.current.getThreadLocal("context_user") s"Hello ${userName}" }