From 598975635f319c5cc2a3ed493042139fb6a157bf Mon Sep 17 00:00:00 2001 From: Alexandru Vasile Date: Thu, 15 Feb 2024 18:07:24 +0200 Subject: [PATCH 1/6] rpc-servers: Add chainHead middleware to capture subscription IDs Signed-off-by: Alexandru Vasile --- substrate/client/rpc-servers/src/lib.rs | 20 +- .../rpc-servers/src/middleware/chain_head.rs | 205 ++++++++++++++++++ .../client/rpc-servers/src/middleware/mod.rs | 1 + 3 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 substrate/client/rpc-servers/src/middleware/chain_head.rs diff --git a/substrate/client/rpc-servers/src/lib.rs b/substrate/client/rpc-servers/src/lib.rs index 29b34b2945b1..bbeee76fc015 100644 --- a/substrate/client/rpc-servers/src/lib.rs +++ b/substrate/client/rpc-servers/src/lib.rs @@ -22,7 +22,13 @@ pub mod middleware; -use std::{convert::Infallible, error::Error as StdError, net::SocketAddr, time::Duration}; +use std::{ + convert::Infallible, + error::Error as StdError, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; use http::header::HeaderValue; use hyper::{ @@ -49,6 +55,8 @@ pub use jsonrpsee::core::{ }; pub use middleware::{MetricsLayer, RpcMetrics}; +use crate::middleware::chain_head::{ChainHeadLayer, ConnectionData}; + const MEGABYTE: u32 = 1024 * 1024; /// Type alias for the JSON-RPC server. @@ -142,6 +150,9 @@ pub async fn start_server( let make_service = make_service_fn(move |_conn: &AddrStream| { let cfg = cfg.clone(); + // Chain head data is per connection. + let chain_head_data = Arc::new(Mutex::new(ConnectionData::default())); + async move { let cfg = cfg.clone(); @@ -152,8 +163,13 @@ pub async fn start_server( let is_websocket = ws::is_upgrade_request(&req); let transport_label = if is_websocket { "ws" } else { "http" }; + // Order of the requests matter here, the metrics layer should be the first to not + // miss metrics. let metrics = metrics.map(|m| MetricsLayer::new(m, transport_label)); - let rpc_middleware = RpcServiceBuilder::new().option_layer(metrics.clone()); + let chain_head = ChainHeadLayer::new(chain_head_data.clone()); + + let rpc_middleware = + RpcServiceBuilder::new().option_layer(metrics.clone()).layer(chain_head); let mut svc = service_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); diff --git a/substrate/client/rpc-servers/src/middleware/chain_head.rs b/substrate/client/rpc-servers/src/middleware/chain_head.rs new file mode 100644 index 000000000000..0a58fc382552 --- /dev/null +++ b/substrate/client/rpc-servers/src/middleware/chain_head.rs @@ -0,0 +1,205 @@ +// This file is part of Substrate. + +// Copyright (C) Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0 + +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +//! RPC middleware to collect prometheus metrics on RPC calls. + +use std::{ + collections::HashSet, + future::Future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use jsonrpsee::{ + server::middleware::rpc::RpcServiceT, + types::{Params, Request}, + MethodResponse, +}; +use pin_project::pin_project; + +/// The per connectin data needed to manage chainHead subscriptions. +#[derive(Default)] +pub struct ConnectionData { + /// Active `chainHeda_follow` subscriptions for this connection. + subscriptions: HashSet, +} + +/// Layer to allow the `chainHead` RPC methods to be called from a single connection. +#[derive(Clone)] +pub struct ChainHeadLayer { + connection_data: Arc>, +} + +impl ChainHeadLayer { + /// Create a new [`ChainHeadLayer`]. + pub fn new(connection_data: Arc>) -> Self { + Self { connection_data } + } +} + +impl tower::Layer for ChainHeadLayer { + type Service = ChainHeadMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + ChainHeadMiddleware::new(inner, self.connection_data.clone()) + } +} + +/// Chain head middleware. +#[derive(Clone)] +pub struct ChainHeadMiddleware { + service: S, + connection_data: Arc>, +} + +impl ChainHeadMiddleware { + /// Create a new chain head middleware. + pub fn new(service: S, connection_data: Arc>) -> ChainHeadMiddleware { + ChainHeadMiddleware { service, connection_data } + } +} + +impl<'a, S> RpcServiceT<'a> for ChainHeadMiddleware +where + S: Send + Sync + RpcServiceT<'a>, +{ + type Future = ResponseFuture; + + fn call(&self, req: Request<'a>) -> Self::Future { + const CHAIN_HEAD_FOLLOW: &str = "chainHead_unstable_follow"; + const CHAIN_HEAD_CALL_METHODS: [&str; 7] = [ + "chainHead_unstable_body", + "chainHead_unstable_header", + "chainHead_unstable_call", + "chainHead_unstable_unpin", + "chainHead_unstable_continue", + "chainHead_unstable_stopOperation", + "chainHead_unstable_unfollow", + ]; + + let method_name = req.method_name(); + + // Intercept the subscription ID returned by the `chainHead_follow` method. + if method_name == CHAIN_HEAD_FOLLOW { + return ResponseFuture { + fut: self.service.call(req.clone()), + connection_data: Some(self.connection_data.clone()), + error: None, + } + } + + // Ensure the subscription ID of those methods corresponds to a subscription ID + // of this connection. + if CHAIN_HEAD_CALL_METHODS.contains(&method_name) { + let params = req.params(); + let follow_subscription = get_subscription_id(params); + + if let Some(follow_subscription) = follow_subscription { + if !self + .connection_data + .lock() + .unwrap() + .subscriptions + .contains(&follow_subscription) + { + log::debug!("{} called without a valid follow subscription", method_name); + + return ResponseFuture { + fut: self.service.call(req.clone()), + connection_data: None, + error: Some(MethodResponse::error( + req.id(), + jsonrpsee::types::error::ErrorObject::owned( + -32602, + "Invalid subscription ID", + None::<()>, + ), + )), + }; + } + } + } + + ResponseFuture { fut: self.service.call(req.clone()), connection_data: None, error: None } + } +} + +/// Extract the subscription ID from the provided parameters. +fn get_subscription_id<'a>(params: Params<'a>) -> Option { + // Support positional parameters. + if let Ok(follow_subscription) = params.sequence().next::() { + return Some(follow_subscription); + } + + let Ok(value) = params.parse::() else { + return None; + }; + + let serde_json::Value::Object(map) = value else { + return None; + }; + + if let Some(serde_json::Value::String(subscription_id)) = map.get("followSubscription") { + return Some(subscription_id.clone()); + } + + None +} + +/// Response future for metrics. +#[pin_project] +pub struct ResponseFuture { + #[pin] + fut: F, + connection_data: Option>>, + error: Option, +} + +impl<'a, F> std::fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("ResponseFuture") + } +} + +impl> Future for ResponseFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Some(err) = this.error.take() { + return Poll::Ready(err); + } + + let res = this.fut.poll(cx); + let connection_data = this.connection_data; + + match (&res, connection_data) { + (Poll::Ready(rp), Some(connection_data)) => + if rp.is_success() { + // if let Some(subscription_id) = rp.subscription_id() { + // connection_data.lock().subscriptions.insert(subscription_id); + // } + }, + _ => {}, + } + + res + } +} diff --git a/substrate/client/rpc-servers/src/middleware/mod.rs b/substrate/client/rpc-servers/src/middleware/mod.rs index 1c1930582441..3c578b0ea230 100644 --- a/substrate/client/rpc-servers/src/middleware/mod.rs +++ b/substrate/client/rpc-servers/src/middleware/mod.rs @@ -18,6 +18,7 @@ //! JSON-RPC specific middleware. +pub mod chain_head; pub mod metrics; pub use metrics::*; From 21615a8d7c6340177374580d764f43e4052c183c Mon Sep 17 00:00:00 2001 From: Alexandru Vasile Date: Thu, 15 Feb 2024 19:14:48 +0200 Subject: [PATCH 2/6] rpc-servers: Capture the method result as subscription ID Signed-off-by: Alexandru Vasile --- .../rpc-servers/src/middleware/chain_head.rs | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/substrate/client/rpc-servers/src/middleware/chain_head.rs b/substrate/client/rpc-servers/src/middleware/chain_head.rs index 0a58fc382552..236480dd8c03 100644 --- a/substrate/client/rpc-servers/src/middleware/chain_head.rs +++ b/substrate/client/rpc-servers/src/middleware/chain_head.rs @@ -83,12 +83,13 @@ where fn call(&self, req: Request<'a>) -> Self::Future { const CHAIN_HEAD_FOLLOW: &str = "chainHead_unstable_follow"; - const CHAIN_HEAD_CALL_METHODS: [&str; 7] = [ + const CHAIN_HEAD_CALL_METHODS: [&str; 8] = [ "chainHead_unstable_body", "chainHead_unstable_header", "chainHead_unstable_call", "chainHead_unstable_unpin", "chainHead_unstable_continue", + "chainHead_unstable_storage", "chainHead_unstable_stopOperation", "chainHead_unstable_unfollow", ]; @@ -97,6 +98,8 @@ where // Intercept the subscription ID returned by the `chainHead_follow` method. if method_name == CHAIN_HEAD_FOLLOW { + println!("Calling chainHEDA method"); + return ResponseFuture { fut: self.service.call(req.clone()), connection_data: Some(self.connection_data.clone()), @@ -107,8 +110,11 @@ where // Ensure the subscription ID of those methods corresponds to a subscription ID // of this connection. if CHAIN_HEAD_CALL_METHODS.contains(&method_name) { + println!("Calling other methods"); + let params = req.params(); let follow_subscription = get_subscription_id(params); + println!("follow_subscription: {:?}", follow_subscription); if let Some(follow_subscription) = follow_subscription { if !self @@ -141,27 +147,50 @@ where } /// Extract the subscription ID from the provided parameters. +/// +/// We make the assumption that all `chainHead` methods are given the +/// subscription ID as a first parameter. +/// +/// This method handles positional and named `camelCase` parameters. fn get_subscription_id<'a>(params: Params<'a>) -> Option { // Support positional parameters. if let Ok(follow_subscription) = params.sequence().next::() { - return Some(follow_subscription); + return Some(follow_subscription) } - let Ok(value) = params.parse::() else { - return None; - }; - - let serde_json::Value::Object(map) = value else { - return None; - }; + // Support named parameters. + let Ok(value) = params.parse::() else { return None }; + let serde_json::Value::Object(map) = value else { return None }; if let Some(serde_json::Value::String(subscription_id)) = map.get("followSubscription") { - return Some(subscription_id.clone()); + return Some(subscription_id.clone()) } None } +/// Extract the result of a jsonrpc object. +/// +/// The function extracts the `result` field from the JSON-RPC response. +/// +/// In this example, the result is `tfMQUZekzJLorGlR`. +/// ```ignore +/// "{"jsonrpc":"2.0","result":"tfMQUZekzJLorGlR","id":0}" +/// ``` +fn get_method_result(response: &MethodResponse) -> Option { + if response.is_error() { + return None + } + + let result = response.as_result(); + let Ok(value) = serde_json::from_str(result) else { return None }; + + let serde_json::Value::Object(map) = value else { return None }; + let Some(serde_json::Value::String(res)) = map.get("result") else { return None }; + + Some(res.clone()) +} + /// Response future for metrics. #[pin_project] pub struct ResponseFuture { @@ -191,12 +220,13 @@ impl> Future for ResponseFuture { let connection_data = this.connection_data; match (&res, connection_data) { - (Poll::Ready(rp), Some(connection_data)) => - if rp.is_success() { - // if let Some(subscription_id) = rp.subscription_id() { - // connection_data.lock().subscriptions.insert(subscription_id); - // } - }, + (Poll::Ready(rp), Some(connection_data)) => { + println!("Response sub: {:?}", rp.to_result()); + if let Some(subscription_id) = get_method_result(rp) { + connection_data.lock().unwrap().subscriptions.insert(subscription_id); + } + }, + _ => {}, } From 0a5e2dc063c9ebc97bcd09c903cd0553fa0735cb Mon Sep 17 00:00:00 2001 From: Alexandru Vasile Date: Thu, 15 Feb 2024 19:34:53 +0200 Subject: [PATCH 3/6] rpc-servers: Cleanup the ResponseFuture Signed-off-by: Alexandru Vasile --- .../rpc-servers/src/middleware/chain_head.rs | 79 +++++++++++-------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/substrate/client/rpc-servers/src/middleware/chain_head.rs b/substrate/client/rpc-servers/src/middleware/chain_head.rs index 236480dd8c03..5453d98eae35 100644 --- a/substrate/client/rpc-servers/src/middleware/chain_head.rs +++ b/substrate/client/rpc-servers/src/middleware/chain_head.rs @@ -16,7 +16,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -//! RPC middleware to collect prometheus metrics on RPC calls. +//! RPC middleware to ensure chainHead methods are called from a single connection. use std::{ collections::HashSet, @@ -100,10 +100,9 @@ where if method_name == CHAIN_HEAD_FOLLOW { println!("Calling chainHEDA method"); - return ResponseFuture { + return ResponseFuture::Register { fut: self.service.call(req.clone()), - connection_data: Some(self.connection_data.clone()), - error: None, + connection_data: self.connection_data.clone(), } } @@ -126,10 +125,8 @@ where { log::debug!("{} called without a valid follow subscription", method_name); - return ResponseFuture { - fut: self.service.call(req.clone()), - connection_data: None, - error: Some(MethodResponse::error( + return ResponseFuture::Ready { + response: Some(MethodResponse::error( req.id(), jsonrpsee::types::error::ErrorObject::owned( -32602, @@ -142,7 +139,7 @@ where } } - ResponseFuture { fut: self.service.call(req.clone()), connection_data: None, error: None } + ResponseFuture::Forward { fut: self.service.call(req.clone()) } } } @@ -191,13 +188,35 @@ fn get_method_result(response: &MethodResponse) -> Option { Some(res.clone()) } -/// Response future for metrics. -#[pin_project] -pub struct ResponseFuture { - #[pin] - fut: F, - connection_data: Option>>, - error: Option, +/// Response future for chainHead middleware. +#[pin_project(project = ResponseFutureProj)] +pub enum ResponseFuture { + /// The response is propagated immediately without calling other layers. + /// + /// This is used in case of an error. + Ready { + /// The response provided to the client directly. + /// + /// This is `Option` to consume the value and return a `MethodResponse` + /// from the projected structure. + response: Option, + }, + + /// Forward the call to another layer. + Forward { + /// The future response value. + #[pin] + fut: F, + }, + + /// Forward the call to another layer and store the subscription ID of the response. + Register { + /// The future response value. + #[pin] + fut: F, + /// Connection data that captures the subscription ID. + connection_data: Arc>, + }, } impl<'a, F> std::fmt::Debug for ResponseFuture { @@ -212,24 +231,20 @@ impl> Future for ResponseFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - if let Some(err) = this.error.take() { - return Poll::Ready(err); - } - - let res = this.fut.poll(cx); - let connection_data = this.connection_data; - - match (&res, connection_data) { - (Poll::Ready(rp), Some(connection_data)) => { - println!("Response sub: {:?}", rp.to_result()); - if let Some(subscription_id) = get_method_result(rp) { - connection_data.lock().unwrap().subscriptions.insert(subscription_id); + match this { + ResponseFutureProj::Ready { response } => + Poll::Ready(response.take().expect("Value is set; qed")), + ResponseFutureProj::Forward { fut } => fut.poll(cx), + ResponseFutureProj::Register { fut, connection_data } => { + let res = fut.poll(cx); + if let Poll::Ready(response) = &res { + if let Some(subscription_id) = get_method_result(response) { + println!("SSub id {:?}", subscription_id); + connection_data.lock().unwrap().subscriptions.insert(subscription_id); + } } + res }, - - _ => {}, } - - res } } From 5b6063f5066291efe10038cbe9c2704e5046f6b4 Mon Sep 17 00:00:00 2001 From: Alexandru Vasile Date: Thu, 15 Feb 2024 19:41:29 +0200 Subject: [PATCH 4/6] rpc-servers: Use parkinglot instead of std mutex Signed-off-by: Alexandru Vasile --- .../rpc-servers/src/middleware/chain_head.rs | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/substrate/client/rpc-servers/src/middleware/chain_head.rs b/substrate/client/rpc-servers/src/middleware/chain_head.rs index 5453d98eae35..604aaad35176 100644 --- a/substrate/client/rpc-servers/src/middleware/chain_head.rs +++ b/substrate/client/rpc-servers/src/middleware/chain_head.rs @@ -22,7 +22,7 @@ use std::{ collections::HashSet, future::Future, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll}, }; @@ -31,6 +31,7 @@ use jsonrpsee::{ types::{Params, Request}, MethodResponse, }; +use parking_lot::Mutex; use pin_project::pin_project; /// The per connectin data needed to manage chainHead subscriptions. @@ -98,8 +99,6 @@ where // Intercept the subscription ID returned by the `chainHead_follow` method. if method_name == CHAIN_HEAD_FOLLOW { - println!("Calling chainHEDA method"); - return ResponseFuture::Register { fut: self.service.call(req.clone()), connection_data: self.connection_data.clone(), @@ -109,22 +108,11 @@ where // Ensure the subscription ID of those methods corresponds to a subscription ID // of this connection. if CHAIN_HEAD_CALL_METHODS.contains(&method_name) { - println!("Calling other methods"); - let params = req.params(); let follow_subscription = get_subscription_id(params); - println!("follow_subscription: {:?}", follow_subscription); if let Some(follow_subscription) = follow_subscription { - if !self - .connection_data - .lock() - .unwrap() - .subscriptions - .contains(&follow_subscription) - { - log::debug!("{} called without a valid follow subscription", method_name); - + if !self.connection_data.lock().subscriptions.contains(&follow_subscription) { return ResponseFuture::Ready { response: Some(MethodResponse::error( req.id(), @@ -239,8 +227,7 @@ impl> Future for ResponseFuture { let res = fut.poll(cx); if let Poll::Ready(response) = &res { if let Some(subscription_id) = get_method_result(response) { - println!("SSub id {:?}", subscription_id); - connection_data.lock().unwrap().subscriptions.insert(subscription_id); + connection_data.lock().subscriptions.insert(subscription_id); } } res From cb7b67ad06791571faa0b8f8c8ee364ef0cd2049 Mon Sep 17 00:00:00 2001 From: Alexandru Vasile <60601340+lexnv@users.noreply.github.com> Date: Thu, 15 Feb 2024 19:44:37 +0200 Subject: [PATCH 5/6] Update substrate/client/rpc-servers/src/middleware/chain_head.rs --- substrate/client/rpc-servers/src/middleware/chain_head.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/substrate/client/rpc-servers/src/middleware/chain_head.rs b/substrate/client/rpc-servers/src/middleware/chain_head.rs index 604aaad35176..a51b323944fb 100644 --- a/substrate/client/rpc-servers/src/middleware/chain_head.rs +++ b/substrate/client/rpc-servers/src/middleware/chain_head.rs @@ -37,7 +37,7 @@ use pin_project::pin_project; /// The per connectin data needed to manage chainHead subscriptions. #[derive(Default)] pub struct ConnectionData { - /// Active `chainHeda_follow` subscriptions for this connection. + /// Active `chainHead_follow` subscriptions for this connection. subscriptions: HashSet, } From 9ae9e9bed5635ea1be6f92b130af3293a2c83923 Mon Sep 17 00:00:00 2001 From: Alexandru Vasile Date: Thu, 15 Feb 2024 19:53:28 +0200 Subject: [PATCH 6/6] rpc-servers: Add parking lot dependency Signed-off-by: Alexandru Vasile --- Cargo.lock | 1 + substrate/client/rpc-servers/Cargo.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index ac4725dd483e..c9c6e81ea87b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16502,6 +16502,7 @@ dependencies = [ "hyper", "jsonrpsee", "log", + "parking_lot 0.12.1", "pin-project", "serde_json", "substrate-prometheus-endpoint", diff --git a/substrate/client/rpc-servers/Cargo.toml b/substrate/client/rpc-servers/Cargo.toml index c8935c36afd9..2bd7825c7b62 100644 --- a/substrate/client/rpc-servers/Cargo.toml +++ b/substrate/client/rpc-servers/Cargo.toml @@ -27,3 +27,4 @@ http = "0.2.8" hyper = "0.14.27" futures = "0.3.29" pin-project = "1.1.3" +parking_lot = "0.12.1"