From e73dff17e9a9909743fd57bbd9d6769789a455e1 Mon Sep 17 00:00:00 2001 From: Mark Drobnak Date: Mon, 13 Jul 2020 14:03:38 -0400 Subject: [PATCH] feat: Route notifications to autopush connection servers (#167) * Generate a message ID for each notification * Use a default TTL of 0 in NotificationHeaders * Impl Serialize for Notification and NotificationHeaders * Add a Router trait and stub WebPushRouter * Use async_trait in Router and add RouterResponse * Implement direct-route-to-node happy-path for WebPushRouter * Replace serde Serialize notification impls with more direct impls * Add DynamoStorage::remove_node_id and use in webpush router The Router async functions may not return Send futures, so the async_trait annotation is modified to not require Send. * Add DynamoStorage::store_message and use in webpush router * Add a UserNotFound error kind to the common code and use in get_uaid * Try to notify the node after storing a message, if available * Remove the `updates.client.deleted` metric Based on metrics spreadsheet discussion. * Add RouterType enum for easy matching on the router type * Take notification by reference when routing * Add Routers extractor and use in webpush_route * Add endpoint_url setting and remove unused database settings * Remove settings banner, as actix already logs the data as INFO * Fix the Location header TODO in webpush router * Add RouterError * Fix incorrect usage of UserNotFound error (remove it) * Return the router response from the HTTP handler * Fix missing Clone impl (rebase error) * Fix serialization of WebPush notifications Notifications were serialized without required fields (due to `skip_serializing` in the autopush-common code). While this is correct when giving the notification to the UA, the connection server needs these fields. We now perform this serialization separately from the autopush-common Notification serialization. * Add debug and trace level logging to the WebPush router * Ignore local configs * Fix wrong content encoding key in header map WebPush expects "encoding" instead of "content_encoding" * Return an error if there is no TTL value * Fix NotificationHeader tests after requiring TTL * Update errnos for RouterError and NoTTL Closes #161 --- .gitignore | 3 + Cargo.lock | 8 +- autoendpoint/Cargo.toml | 2 + autoendpoint/src/error.rs | 35 ++- autoendpoint/src/main.rs | 1 - autoendpoint/src/server/extractors/mod.rs | 1 + .../src/server/extractors/notification.rs | 99 +++++++- .../server/extractors/notification_headers.rs | 66 ++++-- autoendpoint/src/server/extractors/routers.rs | 48 ++++ .../src/server/extractors/subscription.rs | 7 +- autoendpoint/src/server/extractors/user.rs | 46 +++- autoendpoint/src/server/headers/vapid.rs | 3 +- autoendpoint/src/server/mod.rs | 5 + autoendpoint/src/server/routers/mod.rs | 63 ++++++ autoendpoint/src/server/routers/webpush.rs | 212 ++++++++++++++++++ autoendpoint/src/server/routes/webpush.rs | 13 +- autoendpoint/src/settings.rs | 18 +- autopush-common/src/db/mod.rs | 61 ++++- 18 files changed, 626 insertions(+), 65 deletions(-) create mode 100644 autoendpoint/src/server/extractors/routers.rs create mode 100644 autoendpoint/src/server/routers/mod.rs create mode 100644 autoendpoint/src/server/routers/webpush.rs diff --git a/.gitignore b/.gitignore index 9c2ad2c0..d46fed15 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,6 @@ target requirements.txt test-requirements.txt venv + +# Local configs +*.local.toml diff --git a/Cargo.lock b/Cargo.lock index 3157fb85..13fd0358 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -299,7 +299,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "async-trait" -version = "0.1.32" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)", @@ -335,6 +335,7 @@ dependencies = [ "actix-http 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", "actix-rt 1.1.1 (registry+https://github.com/rust-lang/crates.io-index)", "actix-web 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "async-trait 0.1.36 (registry+https://github.com/rust-lang/crates.io-index)", "autopush_common 1.0.0", "backtrace 0.3.44 (registry+https://github.com/rust-lang/crates.io-index)", "base64 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -348,6 +349,7 @@ dependencies = [ "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "openssl 0.10.29 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)", + "reqwest 0.10.6 (registry+https://github.com/rust-lang/crates.io-index)", "sentry 0.18.1 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.111 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.53 (registry+https://github.com/rust-lang/crates.io-index)", @@ -3347,7 +3349,7 @@ name = "trust-dns-proto" version = "0.18.0-alpha.2" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "async-trait 0.1.32 (registry+https://github.com/rust-lang/crates.io-index)", + "async-trait 0.1.36 (registry+https://github.com/rust-lang/crates.io-index)", "enum-as-inner 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -3794,7 +3796,7 @@ dependencies = [ "checksum arrayref 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" "checksum arrayvec 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "cd9fd44efafa8690358b7408d253adf110036b88f55672a933f01d616ad9b1b9" "checksum arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8" -"checksum async-trait 0.1.32 (registry+https://github.com/rust-lang/crates.io-index)" = "0eb7f9ad01405feb3c1dac82463038945cf88eea4569acaf3ad662233496dd96" +"checksum async-trait 0.1.36 (registry+https://github.com/rust-lang/crates.io-index)" = "a265e3abeffdce30b2e26b7a11b222fe37c6067404001b434101457d0385eb92" "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" "checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" "checksum autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d" diff --git a/autoendpoint/Cargo.toml b/autoendpoint/Cargo.toml index d3f4491f..33865780 100644 --- a/autoendpoint/Cargo.toml +++ b/autoendpoint/Cargo.toml @@ -9,6 +9,7 @@ actix-http = "1.0" actix-web = "2.0" actix-rt = "1.0" actix-cors = "0.2.0" +async-trait = "0.1.36" autopush_common = { path = "../autopush-common" } backtrace = "0.3" base64 = "0.12.1" @@ -22,6 +23,7 @@ jsonwebtoken = "7.1.1" lazy_static = "1.4.0" openssl = "0.10" regex = "1.3" +reqwest = "0.10.6" sentry = { version = "0.18", features = ["with_curl_transport"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/autoendpoint/src/error.rs b/autoendpoint/src/error.rs index 16835273..e258e3d6 100644 --- a/autoendpoint/src/error.rs +++ b/autoendpoint/src/error.rs @@ -1,6 +1,6 @@ //! Error types and transformations -use crate::server::VapidError; +use crate::server::{RouterError, VapidError}; use actix_web::{ dev::{HttpResponseBuilder, ServiceResponse}, error::{PayloadError, ResponseError}, @@ -59,6 +59,9 @@ pub enum ApiErrorKind { #[error(transparent)] VapidError(#[from] VapidError), + #[error(transparent)] + Router(#[from] RouterError), + #[error(transparent)] Uuid(#[from] uuid::Error), @@ -84,13 +87,16 @@ pub enum ApiErrorKind { #[error("{0}")] InvalidEncryption(String), - #[error("Data payload must be smaller than {} bytes", .0)] + #[error("Data payload must be smaller than {0} bytes")] PayloadTooLarge(usize), /// Used if the API version given is not v1 or v2 #[error("Invalid API version")] InvalidApiVersion, + #[error("Missing TTL value")] + NoTTL, + #[error("{0}")] Internal(String), } @@ -100,11 +106,13 @@ impl ApiErrorKind { pub fn status(&self) -> StatusCode { match self { ApiErrorKind::PayloadError(e) => e.status_code(), + ApiErrorKind::Router(e) => e.status(), ApiErrorKind::Validation(_) | ApiErrorKind::InvalidEncryption(_) | ApiErrorKind::TokenHashValidation(_) - | ApiErrorKind::Uuid(_) => StatusCode::BAD_REQUEST, + | ApiErrorKind::Uuid(_) + | ApiErrorKind::NoTTL => StatusCode::BAD_REQUEST, ApiErrorKind::NoUser | ApiErrorKind::NoSubscription => StatusCode::GONE, @@ -124,16 +132,27 @@ impl ApiErrorKind { /// Get the associated error number pub fn errno(&self) -> Option { match self { - ApiErrorKind::InvalidEncryption(_) => Some(110), - ApiErrorKind::VapidError(_) - | ApiErrorKind::TokenHashValidation(_) - | ApiErrorKind::Jwt(_) => Some(109), + ApiErrorKind::Router(e) => Some(e.errno()), + ApiErrorKind::InvalidToken => Some(102), + ApiErrorKind::NoUser => Some(103), - ApiErrorKind::NoSubscription => Some(106), + ApiErrorKind::PayloadError(PayloadError::Overflow) | ApiErrorKind::PayloadTooLarge(_) => Some(104), + + ApiErrorKind::NoSubscription => Some(106), + + ApiErrorKind::VapidError(_) + | ApiErrorKind::TokenHashValidation(_) + | ApiErrorKind::Jwt(_) => Some(109), + + ApiErrorKind::InvalidEncryption(_) => Some(110), + + ApiErrorKind::NoTTL => Some(111), + ApiErrorKind::Internal(_) => Some(999), + _ => None, } } diff --git a/autoendpoint/src/main.rs b/autoendpoint/src/main.rs index 84d6fc62..6e281d6e 100644 --- a/autoendpoint/src/main.rs +++ b/autoendpoint/src/main.rs @@ -41,7 +41,6 @@ async fn main() -> Result<(), Box> { let _sentry_guard = configure_sentry(); // Run server... - debug!("{}", settings.banner()); let server = server::Server::with_settings(settings).expect("Could not start server"); info!("Server started"); server.await?; diff --git a/autoendpoint/src/server/extractors/mod.rs b/autoendpoint/src/server/extractors/mod.rs index ef4aa748..1bb62165 100644 --- a/autoendpoint/src/server/extractors/mod.rs +++ b/autoendpoint/src/server/extractors/mod.rs @@ -3,6 +3,7 @@ pub mod notification; pub mod notification_headers; +pub mod routers; pub mod subscription; pub mod token_info; pub mod user; diff --git a/autoendpoint/src/server/extractors/notification.rs b/autoendpoint/src/server/extractors/notification.rs index e1c8103e..549cd67b 100644 --- a/autoendpoint/src/server/extractors/notification.rs +++ b/autoendpoint/src/server/extractors/notification.rs @@ -7,10 +7,15 @@ use actix_web::web::Data; use actix_web::{FromRequest, HttpRequest}; use autopush_common::util::sec_since_epoch; use cadence::Counted; +use fernet::MultiFernet; use futures::{future, FutureExt, StreamExt}; +use std::collections::HashMap; +use uuid::Uuid; /// Extracts notification data from `Subscription` and request data +#[derive(Clone, Debug)] pub struct Notification { + pub message_id: String, pub subscription: Subscription, pub headers: NotificationHeaders, pub timestamp: u64, @@ -52,9 +57,17 @@ impl FromRequest for Notification { }; let headers = NotificationHeaders::from_request(&req, data.is_some())?; + let timestamp = sec_since_epoch(); + let message_id = Self::generate_message_id( + &state.fernet, + &subscription.user.uaid, + &subscription.channel_id, + headers.topic.as_deref(), + timestamp, + ); // Record the encoding if we have an encrypted payload - if let Some(encoding) = &headers.content_encoding { + if let Some(encoding) = &headers.encoding { if data.is_some() { state .metrics @@ -64,12 +77,94 @@ impl FromRequest for Notification { } Ok(Notification { + message_id, subscription, headers, - timestamp: sec_since_epoch(), + timestamp, data, }) } .boxed_local() } } + +impl From for autopush_common::notification::Notification { + fn from(notification: Notification) -> Self { + autopush_common::notification::Notification { + channel_id: notification.subscription.channel_id, + version: notification.message_id, + ttl: notification.headers.ttl as u64, + topic: notification.headers.topic.clone(), + timestamp: notification.timestamp, + data: notification.data, + sortkey_timestamp: Some(notification.timestamp), + headers: Some(notification.headers.into()), + } + } +} + +impl Notification { + /// Generate a message-id suitable for accessing the message + /// + /// For topic messages, a sort_key version of 01 is used, and the topic + /// is included for reference: + /// + /// Encrypted('01' : uaid.hex : channel_id.hex : topic) + /// + /// For non-topic messages, a sort_key version of 02 is used: + /// + /// Encrypted('02' : uaid.hex : channel_id.hex : timestamp) + fn generate_message_id( + fernet: &MultiFernet, + uaid: &Uuid, + channel_id: &Uuid, + topic: Option<&str>, + timestamp: u64, + ) -> String { + let message_id = if let Some(topic) = topic { + format!( + "01:{}:{}:{}", + uaid.to_simple_ref(), + channel_id.to_simple_ref(), + topic + ) + } else { + format!( + "02:{}:{}:{}", + uaid.to_simple_ref(), + channel_id.to_simple_ref(), + timestamp + ) + }; + + fernet.encrypt(message_id.as_bytes()) + } + + /// Serialize the notification for delivery to the connection server. Some + /// fields in `autopush_common`'s `Notification` are marked with + /// `#[serde(skip_serializing)]` so they are not shown to the UA. These + /// fields are still required when delivering to the connection server, so + /// we can't simply convert this notification type to that one and serialize + /// via serde. + pub fn serialize_for_delivery(&self) -> HashMap<&'static str, serde_json::Value> { + let mut map = HashMap::new(); + + map.insert( + "channelID", + serde_json::to_value(&self.subscription.channel_id).unwrap(), + ); + map.insert("version", serde_json::to_value(&self.message_id).unwrap()); + map.insert("ttl", serde_json::to_value(self.headers.ttl).unwrap()); + map.insert("topic", serde_json::to_value(&self.headers.topic).unwrap()); + map.insert("timestamp", serde_json::to_value(self.timestamp).unwrap()); + + if let Some(data) = &self.data { + map.insert("data", serde_json::to_value(&data).unwrap()); + + let headers: HashMap<_, _> = self.headers.clone().into(); + map.insert("headers", serde_json::to_value(&headers).unwrap()); + } + + map + } +} diff --git a/autoendpoint/src/server/extractors/notification_headers.rs b/autoendpoint/src/server/extractors/notification_headers.rs index b7b7e8f9..e4e8920a 100644 --- a/autoendpoint/src/server/extractors/notification_headers.rs +++ b/autoendpoint/src/server/extractors/notification_headers.rs @@ -5,6 +5,7 @@ use actix_web::HttpRequest; use lazy_static::lazy_static; use regex::Regex; use std::cmp::min; +use std::collections::HashMap; use validator::Validate; use validator_derive::Validate; @@ -15,11 +16,11 @@ lazy_static! { const MAX_TTL: i64 = 60 * 60 * 24 * 60; /// Extractor and validator for notification headers -#[derive(Debug, Eq, PartialEq, Validate)] +#[derive(Clone, Debug, Eq, PartialEq, Validate)] pub struct NotificationHeaders { // TTL is a signed value so that validation can catch negative inputs #[validate(range(min = 0, message = "TTL must be greater than 0", code = "114"))] - pub ttl: Option, + pub ttl: i64, #[validate( length( @@ -37,12 +38,37 @@ pub struct NotificationHeaders { // These fields are validated separately, because the validation is complex // and based upon the content encoding - pub content_encoding: Option, + pub encoding: Option, pub encryption: Option, pub encryption_key: Option, pub crypto_key: Option, } +impl From for HashMap { + fn from(headers: NotificationHeaders) -> Self { + let mut map = HashMap::new(); + + map.insert("ttl".to_string(), headers.ttl.to_string()); + if let Some(h) = headers.topic { + map.insert("topic".to_string(), h); + } + if let Some(h) = headers.encoding { + map.insert("encoding".to_string(), h); + } + if let Some(h) = headers.encryption { + map.insert("encryption".to_string(), h); + } + if let Some(h) = headers.encryption_key { + map.insert("encryption_key".to_string(), h); + } + if let Some(h) = headers.crypto_key { + map.insert("crypto_key".to_string(), h); + } + + map + } +} + impl NotificationHeaders { /// Extract the notification headers from a request. /// This can not be implemented as a `FromRequest` impl because we need to @@ -53,9 +79,10 @@ impl NotificationHeaders { let ttl = get_header(req, "ttl") .and_then(|ttl| ttl.parse().ok()) // Enforce a maximum TTL, but don't error - .map(|ttl| min(ttl, MAX_TTL)); + .map(|ttl| min(ttl, MAX_TTL)) + .ok_or(ApiErrorKind::NoTTL)?; let topic = get_owned_header(req, "topic"); - let content_encoding = get_owned_header(req, "content-encoding"); + let encoding = get_owned_header(req, "content-encoding"); let encryption = get_owned_header(req, "encryption"); let encryption_key = get_owned_header(req, "encryption-key"); let crypto_key = get_owned_header(req, "crypto-key"); @@ -63,7 +90,7 @@ impl NotificationHeaders { let headers = NotificationHeaders { ttl, topic, - content_encoding, + encoding, encryption, encryption_key, crypto_key, @@ -84,11 +111,11 @@ impl NotificationHeaders { /// Validate the encryption headers according to the various WebPush /// standard versions fn validate_encryption(&self) -> ApiResult<()> { - let content_encoding = self.content_encoding.as_deref().ok_or_else(|| { + let encoding = self.encoding.as_deref().ok_or_else(|| { ApiErrorKind::InvalidEncryption("Missing Content-Encoding header".to_string()) })?; - match content_encoding { + match encoding { "aesgcm128" => self.validate_encryption_01_rules()?, "aesgcm" => self.validate_encryption_04_rules()?, "aes128gcm" => self.validate_encryption_06_rules()?, @@ -236,7 +263,7 @@ mod tests { let result = NotificationHeaders::from_request(&req, false); assert!(result.is_ok()); - assert_eq!(result.unwrap().ttl, Some(10)); + assert_eq!(result.unwrap().ttl, 10); } /// Negative TTL values are not allowed @@ -269,13 +296,14 @@ mod tests { let result = NotificationHeaders::from_request(&req, false); assert!(result.is_ok()); - assert_eq!(result.unwrap().ttl, Some(MAX_TTL)); + assert_eq!(result.unwrap().ttl, MAX_TTL); } /// A valid topic results in no errors #[test] fn valid_topic() { let req = TestRequest::post() + .header("TTL", "10") .header("TOPIC", "test-topic") .to_http_request(); let result = NotificationHeaders::from_request(&req, false); @@ -288,6 +316,7 @@ mod tests { #[test] fn too_long_topic() { let req = TestRequest::post() + .header("TTL", "10") .header("TOPIC", "test-topic-which-is-too-long-1234") .to_http_request(); let result = NotificationHeaders::from_request(&req, false); @@ -310,7 +339,7 @@ mod tests { /// If there is a payload, there must be a content encoding header #[test] fn payload_without_content_encoding() { - let req = TestRequest::post().to_http_request(); + let req = TestRequest::post().header("TTL", "10").to_http_request(); let result = NotificationHeaders::from_request(&req, true); assert_encryption_error(result, "Missing Content-Encoding header"); @@ -320,6 +349,7 @@ mod tests { #[test] fn valid_01_encryption() { let req = TestRequest::post() + .header("TTL", "10") .header("Content-Encoding", "aesgcm128") .header("Encryption", "salt=foo") .header("Encryption-Key", "dh=bar") @@ -330,9 +360,9 @@ mod tests { assert_eq!( result.unwrap(), NotificationHeaders { - ttl: None, + ttl: 10, topic: None, - content_encoding: Some("aesgcm128".to_string()), + encoding: Some("aesgcm128".to_string()), encryption: Some("salt=foo".to_string()), encryption_key: Some("dh=bar".to_string()), crypto_key: None @@ -344,6 +374,7 @@ mod tests { #[test] fn valid_04_encryption() { let req = TestRequest::post() + .header("TTL", "10") .header("Content-Encoding", "aesgcm") .header("Encryption", "salt=foo") .header("Crypto-Key", "dh=bar") @@ -354,9 +385,9 @@ mod tests { assert_eq!( result.unwrap(), NotificationHeaders { - ttl: None, + ttl: 10, topic: None, - content_encoding: Some("aesgcm".to_string()), + encoding: Some("aesgcm".to_string()), encryption: Some("salt=foo".to_string()), encryption_key: None, crypto_key: Some("dh=bar".to_string()) @@ -368,6 +399,7 @@ mod tests { #[test] fn valid_06_encryption() { let req = TestRequest::post() + .header("TTL", "10") .header("Content-Encoding", "aes128gcm") .header("Encryption", "notsalt=foo") .header("Crypto-Key", "notdh=bar") @@ -378,9 +410,9 @@ mod tests { assert_eq!( result.unwrap(), NotificationHeaders { - ttl: None, + ttl: 10, topic: None, - content_encoding: Some("aes128gcm".to_string()), + encoding: Some("aes128gcm".to_string()), encryption: Some("notsalt=foo".to_string()), encryption_key: None, crypto_key: Some("notdh=bar".to_string()) diff --git a/autoendpoint/src/server/extractors/routers.rs b/autoendpoint/src/server/extractors/routers.rs new file mode 100644 index 00000000..41405fe9 --- /dev/null +++ b/autoendpoint/src/server/extractors/routers.rs @@ -0,0 +1,48 @@ +use crate::server::extractors::user::RouterType; +use crate::server::routers::webpush::WebPushRouter; +use crate::server::routers::Router; +use crate::server::ServerState; +use actix_web::dev::{Payload, PayloadStream}; +use actix_web::web::Data; +use actix_web::{FromRequest, HttpRequest}; +use futures::future; + +/// Holds the various notification routers. The routers use resources from the +/// server state, which is why `Routers` is an extractor. +pub struct Routers { + pub webpush: WebPushRouter, +} + +impl FromRequest for Routers { + type Error = (); + type Future = future::Ready>; + type Config = (); + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let state = Data::::extract(&req) + .into_inner() + .expect("No server state found"); + + future::ok(Routers { + webpush: WebPushRouter { + ddb: state.ddb.clone(), + metrics: state.metrics.clone(), + http: state.http.clone(), + endpoint_url: state.settings.endpoint_url.clone(), + }, + }) + } +} + +impl Routers { + /// Get the router which handles the router type + pub fn get(&self, router_type: RouterType) -> &dyn Router { + match router_type { + RouterType::WebPush => &self.webpush, + RouterType::GCM => unimplemented!(), + RouterType::FCM => unimplemented!(), + RouterType::APNS => unimplemented!(), + RouterType::ADM => unimplemented!(), + } + } +} diff --git a/autoendpoint/src/server/extractors/subscription.rs b/autoendpoint/src/server/extractors/subscription.rs index 77520b0a..c09cb42a 100644 --- a/autoendpoint/src/server/extractors/subscription.rs +++ b/autoendpoint/src/server/extractors/subscription.rs @@ -1,6 +1,6 @@ use crate::error::{ApiError, ApiErrorKind, ApiResult}; use crate::server::extractors::token_info::{ApiVersion, TokenInfo}; -use crate::server::extractors::user::validate_user; +use crate::server::extractors::user::{validate_user, RouterType}; use crate::server::headers::crypto_key::CryptoKeyHeader; use crate::server::headers::vapid::{VapidHeader, VapidHeaderWithKey, VapidVersionData}; use crate::server::{ServerState, VapidError}; @@ -19,9 +19,11 @@ use std::borrow::Cow; use uuid::Uuid; /// Extracts subscription data from `TokenInfo` and verifies auth/crypto headers +#[derive(Clone, Debug)] pub struct Subscription { pub user: DynamoDbUser, pub channel_id: Uuid, + pub router_type: RouterType, pub vapid: Option, } @@ -65,7 +67,7 @@ impl FromRequest for Subscription { .await .map_err(ApiErrorKind::Database)? .ok_or(ApiErrorKind::NoUser)?; - validate_user(&user, &channel_id, &state).await?; + let router_type = validate_user(&user, &channel_id, &state).await?; // Validate the VAPID JWT token and record the version if let Some(vapid) = &vapid { @@ -79,6 +81,7 @@ impl FromRequest for Subscription { Ok(Subscription { user, channel_id, + router_type, vapid, }) } diff --git a/autoendpoint/src/server/extractors/user.rs b/autoendpoint/src/server/extractors/user.rs index 61135ed4..d469412d 100644 --- a/autoendpoint/src/server/extractors/user.rs +++ b/autoendpoint/src/server/extractors/user.rs @@ -5,31 +5,59 @@ use crate::server::ServerState; use autopush_common::db::{DynamoDbUser, DynamoStorage}; use cadence::{Counted, StatsdClient}; use futures::compat::Future01CompatExt; +use std::str::FromStr; use uuid::Uuid; /// Valid `DynamoDbUser::router_type` values -const VALID_ROUTERS: [&str; 5] = ["webpush", "gcm", "fcm", "apns", "adm"]; +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum RouterType { + WebPush, + GCM, + FCM, + APNS, + ADM, +} + +impl FromStr for RouterType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "webpush" => Ok(RouterType::WebPush), + "gcm" => Ok(RouterType::GCM), + "fcm" => Ok(RouterType::FCM), + "apns" => Ok(RouterType::APNS), + "adm" => Ok(RouterType::ADM), + _ => Err(()), + } + } +} /// Perform some validations on the user, including: /// - Validate router type /// - (WebPush) Check that the subscription/channel exists /// - (WebPush) Drop user if inactive +/// +/// Returns an enum representing the user's router type. pub async fn validate_user( user: &DynamoDbUser, channel_id: &Uuid, state: &ServerState, -) -> ApiResult<()> { - if !VALID_ROUTERS.contains(&user.router_type.as_str()) { - debug!("Unknown router type, dropping user"; "user" => ?user); - drop_user(&user.uaid, &state.ddb, &state.metrics).await?; - return Err(ApiErrorKind::NoSubscription.into()); - } +) -> ApiResult { + let router_type = match user.router_type.parse::() { + Ok(router_type) => router_type, + Err(_) => { + debug!("Unknown router type, dropping user"; "user" => ?user); + drop_user(&user.uaid, &state.ddb, &state.metrics).await?; + return Err(ApiErrorKind::NoSubscription.into()); + } + }; - if user.router_type == "webpush" { + if router_type == RouterType::WebPush { validate_webpush_user(user, channel_id, &state.ddb, &state.metrics).await?; } - Ok(()) + Ok(router_type) } /// Make sure the user is not inactive and the subscription channel exists diff --git a/autoendpoint/src/server/headers/vapid.rs b/autoendpoint/src/server/headers/vapid.rs index 2d00eb74..9bdcb74b 100644 --- a/autoendpoint/src/server/headers/vapid.rs +++ b/autoendpoint/src/server/headers/vapid.rs @@ -5,7 +5,7 @@ use thiserror::Error; const ALLOWED_SCHEMES: [&str; 3] = ["bearer", "webpush", "vapid"]; /// Parses the VAPID authorization header -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct VapidHeader { pub scheme: String, pub token: String, @@ -14,6 +14,7 @@ pub struct VapidHeader { /// Combines the VAPID header details with the public key, which may not be from /// the VAPID header +#[derive(Clone, Debug)] pub struct VapidHeaderWithKey { pub vapid: VapidHeader, pub public_key: String, diff --git a/autoendpoint/src/server/mod.rs b/autoendpoint/src/server/mod.rs index 420f881d..29683687 100644 --- a/autoendpoint/src/server/mod.rs +++ b/autoendpoint/src/server/mod.rs @@ -18,9 +18,11 @@ use std::sync::Arc; mod extractors; mod headers; +mod routers; mod routes; pub use headers::vapid::VapidError; +pub use routers::RouterError; #[derive(Clone)] pub struct ServerState { @@ -29,6 +31,7 @@ pub struct ServerState { pub settings: Settings, pub fernet: Arc, pub ddb: DynamoStorage, + pub http: reqwest::Client, } pub struct Server; @@ -44,11 +47,13 @@ impl Server { metrics.clone(), ) .map_err(ApiErrorKind::Database)?; + let http = reqwest::Client::new(); let state = ServerState { metrics, settings, fernet, ddb, + http, }; let server = HttpServer::new(move || { diff --git a/autoendpoint/src/server/routers/mod.rs b/autoendpoint/src/server/routers/mod.rs new file mode 100644 index 00000000..3d617472 --- /dev/null +++ b/autoendpoint/src/server/routers/mod.rs @@ -0,0 +1,63 @@ +//! Routers route notifications to user agents + +use crate::error::ApiResult; +use crate::server::extractors::notification::Notification; +use actix_web::http::StatusCode; +use actix_web::HttpResponse; +use async_trait::async_trait; +use std::collections::HashMap; +use thiserror::Error; + +pub mod webpush; + +#[async_trait(?Send)] +pub trait Router { + /// Route a notification to the user + async fn route_notification(&self, notification: &Notification) -> ApiResult; +} + +/// The response returned when a router routes a notification +pub struct RouterResponse { + pub status: StatusCode, + pub headers: HashMap<&'static str, String>, + pub body: Option, +} + +impl From for HttpResponse { + fn from(router_response: RouterResponse) -> Self { + let mut builder = HttpResponse::build(router_response.status); + + for (key, value) in router_response.headers { + builder.set_header(key, value); + } + + builder.body(router_response.body.unwrap_or_default()) + } +} + +#[derive(Debug, Error)] +pub enum RouterError { + #[error("Database error while saving notification")] + SaveDb(#[source] autopush_common::errors::Error), + + #[error("User was deleted during routing")] + UserWasDeleted, +} + +impl RouterError { + /// Get the associated HTTP status code + pub fn status(&self) -> StatusCode { + match self { + RouterError::SaveDb(_) => StatusCode::SERVICE_UNAVAILABLE, + RouterError::UserWasDeleted => StatusCode::GONE, + } + } + + /// Get the associated error number + pub fn errno(&self) -> usize { + match self { + RouterError::SaveDb(_) => 201, + RouterError::UserWasDeleted => 105, + } + } +} diff --git a/autoendpoint/src/server/routers/webpush.rs b/autoendpoint/src/server/routers/webpush.rs new file mode 100644 index 00000000..e7025184 --- /dev/null +++ b/autoendpoint/src/server/routers/webpush.rs @@ -0,0 +1,212 @@ +//! The router for desktop user agents. +//! +//! These agents are connected via an Autopush connection server. The correct +//! server is located via the database routing table. If the server is busy or +//! not available, the notification is stored in the database. + +use crate::error::{ApiErrorKind, ApiResult}; +use crate::server::extractors::notification::Notification; +use crate::server::routers::{Router, RouterResponse}; +use crate::server::RouterError; +use async_trait::async_trait; +use autopush_common::db::{DynamoDbUser, DynamoStorage}; +use cadence::{Counted, StatsdClient}; +use futures::compat::Future01CompatExt; +use reqwest::{Response, StatusCode}; +use std::collections::HashMap; +use url::Url; +use uuid::Uuid; + +pub struct WebPushRouter { + pub ddb: DynamoStorage, + pub metrics: StatsdClient, + pub http: reqwest::Client, + pub endpoint_url: Url, +} + +#[async_trait(?Send)] +impl Router for WebPushRouter { + async fn route_notification(&self, notification: &Notification) -> ApiResult { + let user = ¬ification.subscription.user; + debug!( + "Routing WebPush notification to UAID {}", + notification.subscription.user.uaid + ); + trace!("Notification = {:?}", notification); + + // Check if there is a node connected to the client + if let Some(node_id) = &user.node_id { + trace!("User has a node ID, sending notification to node"); + + // Try to send the notification to the node + match self.send_notification(notification, node_id).await { + Ok(response) => { + // The node might be busy, make sure it accepted the notification + if response.status() == 200 { + // The node has received the notification + trace!("Node received notification"); + return Ok(self.make_delivered_response(notification)); + } + + trace!( + "Node did not receive the notification, response = {:?}", + response + ); + } + Err(error) => { + // We should stop sending notifications to this node for this user + debug!("Error while sending webpush notification: {}", error); + self.remove_node_id(user, node_id.clone()).await?; + } + } + } + + // Save notification, node is not present or busy + trace!("Node is not present or busy, storing notification"); + self.store_notification(notification).await?; + + // Retrieve the user data again, they may have reconnected or the node + // is no longer busy. + trace!("Re-fetching user to trigger notification check"); + let user = match self.ddb.get_user(&user.uaid).compat().await { + Ok(Some(user)) => user, + Ok(None) => { + trace!("No user found, must have been deleted"); + return Err(ApiErrorKind::Router(RouterError::UserWasDeleted).into()); + } + Err(e) => { + // Database error, but we already stored the message so it's ok + debug!("Database error while re-fetching user: {}", e); + return Ok(self.make_stored_response(notification)); + } + }; + + // Try to notify the node the user is currently connected to + let node_id = match &user.node_id { + Some(id) => id, + // The user is not connected to a node, nothing more to do + None => { + trace!("User is not connected to a node, returning stored response"); + return Ok(self.make_stored_response(notification)); + } + }; + + // Notify the node to check for messages + trace!("Notifying node to check for messages"); + match self.trigger_notification_check(&user.uaid, &node_id).await { + Ok(response) => { + trace!("Response = {:?}", response); + if response.status() == 200 { + trace!("Node has delivered the message"); + Ok(self.make_delivered_response(notification)) + } else { + trace!("Node has not delivered the message, returning stored response"); + Ok(self.make_stored_response(notification)) + } + } + Err(error) => { + // Can't communicate with the node, so we should stop using it + debug!("Error while triggering notification check: {}", error); + self.remove_node_id(&user, node_id.clone()).await?; + Ok(self.make_stored_response(notification)) + } + } + } +} + +impl WebPushRouter { + /// Send the notification to the node + async fn send_notification( + &self, + notification: &Notification, + node_id: &str, + ) -> Result { + let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid); + let notification = notification.serialize_for_delivery(); + + self.http.put(&url).json(¬ification).send().await + } + + /// Notify the node to check for notifications for the user + async fn trigger_notification_check( + &self, + uaid: &Uuid, + node_id: &str, + ) -> Result { + let url = format!("{}/notif/{}", node_id, uaid); + + self.http.put(&url).send().await + } + + /// Store a notification in the database + async fn store_notification(&self, notification: &Notification) -> ApiResult<()> { + self.ddb + .store_message( + ¬ification.subscription.user.uaid, + notification + .subscription + .user + .current_month + .clone() + .unwrap_or_else(|| self.ddb.current_message_month.clone()), + notification.clone().into(), + ) + .compat() + .await + .map_err(|e| ApiErrorKind::Router(RouterError::SaveDb(e)).into()) + } + + /// Remove the node ID from a user. This is done if the user is no longer + /// connected to the node. + async fn remove_node_id(&self, user: &DynamoDbUser, node_id: String) -> ApiResult<()> { + self.metrics.incr("updates.client.host_gone").ok(); + + self.ddb + .remove_node_id(&user.uaid, node_id, user.connected_at) + .compat() + .await + .map_err(|e| ApiErrorKind::Database(e).into()) + } + + /// Update metrics and create a response for when a notification has been directly forwarded to + /// an autopush server. + fn make_delivered_response(&self, notification: &Notification) -> RouterResponse { + self.make_response(notification, "Direct", StatusCode::OK) + } + + /// Update metrics and create a response for when a notification has been stored in the database + /// for future transmission. + fn make_stored_response(&self, notification: &Notification) -> RouterResponse { + self.make_response(notification, "Stored", StatusCode::ACCEPTED) + } + + /// Update metrics and create a response after routing a notification + fn make_response( + &self, + notification: &Notification, + destination_tag: &str, + status: StatusCode, + ) -> RouterResponse { + self.metrics + .count_with_tags( + "notification.message_data", + notification.data.as_ref().map(String::len).unwrap_or(0) as i64, + ) + .with_tag("destination", destination_tag) + .send(); + + RouterResponse { + status, + headers: { + let mut map = HashMap::new(); + map.insert( + "Location", + format!("{}/m/{}", self.endpoint_url, notification.message_id), + ); + map.insert("TTL", notification.headers.ttl.to_string()); + map + }, + body: None, + } + } +} diff --git a/autoendpoint/src/server/routes/webpush.rs b/autoendpoint/src/server/routes/webpush.rs index c87f00c3..80c8d567 100644 --- a/autoendpoint/src/server/routes/webpush.rs +++ b/autoendpoint/src/server/routes/webpush.rs @@ -1,7 +1,16 @@ +use crate::error::ApiResult; use crate::server::extractors::notification::Notification; +use crate::server::extractors::routers::Routers; use actix_web::HttpResponse; /// Handle the `/wpush/{api_version}/{token}` and `/wpush/{token}` routes -pub async fn webpush_route(_notification: Notification) -> HttpResponse { - HttpResponse::Ok().finish() +pub async fn webpush_route( + notification: Notification, + routers: Routers, +) -> ApiResult { + let router = routers.get(notification.subscription.router_type); + + let response = router.route_notification(¬ification).await?; + + Ok(response.into()) } diff --git a/autoendpoint/src/settings.rs b/autoendpoint/src/settings.rs index a9986dc5..42f4ec4a 100644 --- a/autoendpoint/src/settings.rs +++ b/autoendpoint/src/settings.rs @@ -14,10 +14,7 @@ pub struct Settings { pub debug: bool, pub port: u16, pub host: String, - pub database_url: String, - pub database_pool_max_size: Option, - #[cfg(any(test, feature = "db_test"))] - pub database_use_test_transactions: bool, + pub endpoint_url: Url, pub router_table_name: String, pub message_table_name: String, @@ -37,10 +34,7 @@ impl Default for Settings { debug: false, port: DEFAULT_PORT, host: "127.0.0.1".to_string(), - database_url: "mysql://root@127.0.0.1/autopush".to_string(), - database_pool_max_size: None, - #[cfg(any(test, feature = "db_test"))] - database_use_test_transactions: false, + endpoint_url: Url::parse("http://127.0.0.1:8000/").unwrap(), router_table_name: "router".to_string(), message_table_name: "message".to_string(), max_data_bytes: 4096, @@ -86,14 +80,6 @@ impl Settings { }) } - /// A simple banner for display of certain settings at startup - pub fn banner(&self) -> String { - let db = Url::parse(&self.database_url) - .map(|url| url.scheme().to_owned()) - .unwrap_or_else(|_| "".to_owned()); - format!("http://{}:{} ({})", self.host, self.port, db) - } - /// Initialize the fernet encryption instance pub fn make_fernet(&self) -> MultiFernet { if !(self.crypto_keys.starts_with('[') && self.crypto_keys.ends_with(']')) { diff --git a/autopush-common/src/db/mod.rs b/autopush-common/src/db/mod.rs index 291d508d..ec01314e 100644 --- a/autopush-common/src/db/mod.rs +++ b/autopush-common/src/db/mod.rs @@ -8,8 +8,8 @@ use futures_backoff::retry_if; use rusoto_core::{HttpClient, Region}; use rusoto_credential::StaticProvider; use rusoto_dynamodb::{ - AttributeValue, BatchWriteItemInput, DeleteItemInput, DynamoDb, DynamoDbClient, PutRequest, - UpdateItemInput, UpdateItemOutput, WriteRequest, + AttributeValue, BatchWriteItemInput, DeleteItemInput, DynamoDb, DynamoDbClient, PutItemInput, + PutRequest, UpdateItemInput, UpdateItemOutput, WriteRequest, }; #[macro_use] @@ -23,8 +23,8 @@ use crate::notification::Notification; use crate::util::timing::sec_since_epoch; use self::commands::{ - retryable_batchwriteitem_error, retryable_delete_error, retryable_updateitem_error, - FetchMessageResponse, + retryable_batchwriteitem_error, retryable_delete_error, retryable_putitem_error, + retryable_updateitem_error, FetchMessageResponse, }; pub use self::models::{DynamoDbNotification, DynamoDbUser}; @@ -310,6 +310,29 @@ impl DynamoStorage { .chain_err(|| "Unable to migrate user") } + /// Store a single message + pub fn store_message( + &self, + uaid: &Uuid, + message_month: String, + message: Notification, + ) -> impl Future { + let ddb = self.ddb.clone(); + let put_item = PutItemInput { + item: serde_dynamodb::to_hashmap(&DynamoDbNotification::from_notif(uaid, message)) + .unwrap(), + table_name: message_month, + ..Default::default() + }; + + retry_if( + move || ddb.put_item(put_item.clone()), + retryable_putitem_error, + ) + .and_then(|_| future::ok(())) + .chain_err(|| "Error saving notification") + } + /// Store a batch of messages when shutting down pub fn store_messages( &self, @@ -472,6 +495,36 @@ impl DynamoStorage { .collect::>() }) } + + /// Remove the node ID from a user in the router table. + /// The node ID will only be cleared if `connected_at` matches up + /// with the item's `connected_at`. + pub fn remove_node_id( + &self, + uaid: &Uuid, + node_id: String, + connected_at: u64, + ) -> impl Future { + let ddb = self.ddb.clone(); + let update_item = UpdateItemInput { + key: ddb_item! { uaid: s => uaid.to_simple().to_string() }, + update_expression: Some("REMOVE node_id".to_string()), + condition_expression: Some("(node_id = :node) and (connected_at = :conn)".to_string()), + expression_attribute_values: Some(hashmap! { + ":node".to_string() => val!(S => node_id), + ":conn".to_string() => val!(N => connected_at.to_string()) + }), + table_name: self.router_table_name.clone(), + ..Default::default() + }; + + retry_if( + move || ddb.update_item(update_item.clone()), + retryable_updateitem_error, + ) + .and_then(|_| future::ok(())) + .chain_err(|| "Error removing node ID") + } } pub fn list_message_tables(ddb: &DynamoDbClient, prefix: &str) -> Result> {