From a914bc5bcc60bd9d5dc37c49ebc1e9b1b6bfda52 Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Tue, 7 Nov 2023 10:30:17 +1300 Subject: [PATCH] Update aws-sigv4 to `>= 0.57` (#201) * cargo fmt Signed-off-by: Thomas Farr * Update aws-sigv4 to `>= 0.57` - Expose signing time source as configurable for testability. - Explicitly add global default headers upfront rather than via reqwest::Client so that they are taken into account for signing. - Add test cases for sigv4 based off those in opensearch-net. Signed-off-by: Thomas Farr * Changelog Signed-off-by: Thomas Farr --------- Signed-off-by: Thomas Farr --- CHANGELOG.md | 2 + opensearch/Cargo.toml | 17 ++- opensearch/examples/advanced_index_actions.rs | 28 +++-- opensearch/examples/index_lifecycle.rs | 20 +++- opensearch/examples/index_template.rs | 27 +++-- opensearch/examples/json.rs | 61 ++++++---- opensearch/src/http/aws_auth.rs | 105 ++++++++++-------- opensearch/src/http/transport.rs | 46 ++++++-- opensearch/src/lib.rs | 4 +- opensearch/tests/auth.rs | 12 +- opensearch/tests/aws_auth.rs | 92 +++++++++++++-- opensearch/tests/client.rs | 12 +- opensearch/tests/common/mod.rs | 23 +++- opensearch/tests/common/server.rs | 19 +++- 14 files changed, 329 insertions(+), 139 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8a507e6..1391d12d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Added - Added InfoResponse structure ([#187](https://github.com/opensearch-project/opensearch-rs/pull/187)) - Added documentation on how to make raw json requests ([#196](https://github.com/opensearch-project/opensearch-rs/pull/196)) + ### Dependencies - Bumps `sysinfo` from 0.28.0 to 0.29.0 - Bumps `serde_with` from ~2 to ~3 @@ -13,6 +14,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Bumps `syn` from 1.0 to 2.0 - Bumps `toml` from 0.7.1 to 0.8.0 - Bumps `dialoguer` from 0.10.2 to 0.11.0 +- Bumps `aws-*` from >=0.53 to >=0.57 ([#201](https://github.com/opensearch-project/opensearch-rs/pull/201)) ### Changed - Moved @aditjind to Emeritus maintainers ([#170](https://github.com/opensearch-project/opensearch-rs/pull/170)) diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index ec21556b..cefd5253 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -26,7 +26,7 @@ native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] # AWS SigV4 Auth support -aws-auth = ["aws-credential-types", "aws-sigv4", "aws-types"] +aws-auth = ["aws-credential-types", "aws-sigv4", "aws-smithy-runtime-api", "aws-types"] [dependencies] base64 = "0.21" @@ -40,13 +40,15 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" serde_with = "3" void = "1.0.2" -aws-credential-types = { version = ">= 0.53", optional = true } -aws-sigv4 = { version = ">= 0.53", optional = true } -aws-types = { version = ">= 0.53", optional = true } +aws-credential-types = { version = ">= 0.57", optional = true } +aws-sigv4 = { version = ">= 0.57", optional = true } +aws-smithy-runtime-api = { version = ">= 0.57", optional = true, features = ["client"]} +aws-types = { version = ">= 0.57", optional = true } [dev-dependencies] anyhow = "1.0" -aws-config = ">= 0.53" +aws-config = ">= 0.57" +aws-smithy-async = ">= 0.57" chrono = { version = "0.4", features = ["serde"] } clap = "2" futures = "0.3.1" @@ -54,8 +56,11 @@ http = "0.2" hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "server"] } regex="1.4" sysinfo = "0.29.0" +test-case = "3" textwrap = "0.16" -tokio = { version = "1.0", default-features = false, features = ["macros", "net", "time", "rt-multi-thread"] } +tokio = { version = "1.0", default-features = false, features = ["macros", "net", "time", "rt-multi-thread", "sync"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } xml-rs = "0.8" [build-dependencies] diff --git a/opensearch/examples/advanced_index_actions.rs b/opensearch/examples/advanced_index_actions.rs index 8310b587..b94bba50 100644 --- a/opensearch/examples/advanced_index_actions.rs +++ b/opensearch/examples/advanced_index_actions.rs @@ -1,12 +1,24 @@ -use opensearch::auth::Credentials; -use opensearch::indices::{ - IndicesAddBlockParts, IndicesClearCacheParts, IndicesCloneParts, IndicesCloseParts, - IndicesCreateParts, IndicesDeleteParts, IndicesFlushParts, IndicesForcemergeParts, - IndicesOpenParts, IndicesPutSettingsParts, IndicesRefreshParts, IndicesSplitParts, -}; +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + use opensearch::{ - cert::CertificateValidation, http::transport::SingleNodeConnectionPool, - http::transport::TransportBuilder, OpenSearch, + auth::Credentials, + cert::CertificateValidation, + http::transport::{SingleNodeConnectionPool, TransportBuilder}, + indices::{ + IndicesAddBlockParts, IndicesClearCacheParts, IndicesCloneParts, IndicesCloseParts, + IndicesCreateParts, IndicesDeleteParts, IndicesFlushParts, IndicesForcemergeParts, + IndicesOpenParts, IndicesPutSettingsParts, IndicesRefreshParts, IndicesSplitParts, + }, + OpenSearch, }; use serde_json::json; use url::Url; diff --git a/opensearch/examples/index_lifecycle.rs b/opensearch/examples/index_lifecycle.rs index 2673a7ed..879c582e 100644 --- a/opensearch/examples/index_lifecycle.rs +++ b/opensearch/examples/index_lifecycle.rs @@ -1,13 +1,23 @@ -use opensearch::auth::Credentials; -use opensearch::cert::CertificateValidation; -use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder}; -use opensearch::OpenSearch; +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + use opensearch::{ + auth::Credentials, + cert::CertificateValidation, + http::transport::{SingleNodeConnectionPool, TransportBuilder}, indices::{ IndicesCreateParts, IndicesDeleteParts, IndicesExistsParts, IndicesGetParts, IndicesPutMappingParts, IndicesPutSettingsParts, }, - IndexParts, + IndexParts, OpenSearch, }; use serde_json::json; use url::Url; diff --git a/opensearch/examples/index_template.rs b/opensearch/examples/index_template.rs index 102124bf..2a8a7801 100644 --- a/opensearch/examples/index_template.rs +++ b/opensearch/examples/index_template.rs @@ -1,12 +1,23 @@ -use opensearch::auth::Credentials; -use opensearch::cluster::{ClusterDeleteComponentTemplateParts, ClusterPutComponentTemplateParts}; -use opensearch::indices::{ - IndicesCreateParts, IndicesDeleteIndexTemplateParts, IndicesDeleteParts, - IndicesGetIndexTemplateParts, IndicesGetParts, IndicesPutIndexTemplateParts, -}; +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ use opensearch::{ - cert::CertificateValidation, http::transport::SingleNodeConnectionPool, - http::transport::TransportBuilder, OpenSearch, + auth::Credentials, + cert::CertificateValidation, + cluster::{ClusterDeleteComponentTemplateParts, ClusterPutComponentTemplateParts}, + http::transport::{SingleNodeConnectionPool, TransportBuilder}, + indices::{ + IndicesCreateParts, IndicesDeleteIndexTemplateParts, IndicesDeleteParts, + IndicesGetIndexTemplateParts, IndicesGetParts, IndicesPutIndexTemplateParts, + }, + OpenSearch, }; use serde_json::json; use url::Url; diff --git a/opensearch/examples/json.rs b/opensearch/examples/json.rs index 2b3395f2..d8767c37 100644 --- a/opensearch/examples/json.rs +++ b/opensearch/examples/json.rs @@ -1,13 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +use opensearch::{ + auth::Credentials, + cert::CertificateValidation, + http::{ + headers::HeaderMap, + request::JsonBody, + transport::{SingleNodeConnectionPool, TransportBuilder}, + Method, Url, + }, + OpenSearch, +}; use serde_json::{json, Value}; -use opensearch::auth::Credentials; -use opensearch::cert::CertificateValidation; -use opensearch::http::headers::HeaderMap; -use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder}; -use opensearch::http::{Method, Url}; -use opensearch::http::request::JsonBody; -use opensearch::OpenSearch; - #[tokio::main] async fn main() -> Result<(), Box> { let url = Url::parse("https://localhost:9200")?; @@ -22,14 +36,7 @@ async fn main() -> Result<(), Box> { let document_id = "1"; let info: Value = client - .send::<(), ()>( - Method::Get, - "/", - HeaderMap::new(), - None, - None, - None, - ) + .send::<(), ()>(Method::Get, "/", HeaderMap::new(), None, None, None) .await? .json() .await?; @@ -40,13 +47,14 @@ async fn main() -> Result<(), Box> { ); // Create an index - let index_body : JsonBody<_> = json!({ + let index_body: JsonBody<_> = json!({ "settings": { "index": { "number_of_shards" : 4 } } - }).into(); + }) + .into(); let create_index_response = client .send( @@ -66,7 +74,8 @@ async fn main() -> Result<(), Box> { "title": "Moneyball", "director": "Bennett Miller", "year": "2011" - }).into(); + }) + .into(); let create_document_response = client .send( Method::Put, @@ -82,7 +91,7 @@ async fn main() -> Result<(), Box> { // Search for a document let q = "miller"; - let query : JsonBody<_> = json!({ + let query: JsonBody<_> = json!({ "size": 5, "query": { "multi_match": { @@ -90,7 +99,8 @@ async fn main() -> Result<(), Box> { "fields": ["title^2", "director"] } } - }).into(); + }) + .into(); let search_response = client .send( @@ -105,11 +115,14 @@ async fn main() -> Result<(), Box> { assert_eq!(search_response.status_code(), 200); let search_result = search_response.json::().await?; - println!("Hits: {:#?}", search_result["hits"]["hits"].as_array().unwrap()); + println!( + "Hits: {:#?}", + search_result["hits"]["hits"].as_array().unwrap() + ); // Delete the document let delete_document_response = client - .send::<(),()>( + .send::<(), ()>( Method::Delete, &format!("/{index_name}/_doc/{document_id}"), HeaderMap::new(), @@ -123,7 +136,7 @@ async fn main() -> Result<(), Box> { // Delete the index let delete_response = client - .send::<(),()>( + .send::<(), ()>( Method::Delete, &format!("/{index_name}"), HeaderMap::new(), diff --git a/opensearch/src/http/aws_auth.rs b/opensearch/src/http/aws_auth.rs index ffa33c76..6150ffd0 100644 --- a/opensearch/src/http/aws_auth.rs +++ b/opensearch/src/http/aws_auth.rs @@ -9,70 +9,83 @@ * GitHub history for details. */ -use std::time::SystemTime; - -use aws_credential_types::{ - provider::{ProvideCredentials, SharedCredentialsProvider}, - Credentials, -}; +use crate::http::headers::HeaderValue; +use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; use aws_sigv4::{ http_request::{ sign, PayloadChecksumKind, SignableBody, SignableRequest, SigningParams, SigningSettings, }, - signing_params::BuildError, + sign::v4, }; -use aws_types::region::Region; +use aws_smithy_runtime_api::client::identity::Identity; +use aws_types::{region::Region, sdk_config::SharedTimeSource}; use reqwest::Request; -fn get_signing_params<'a>( - credentials: &'a Credentials, - service_name: &'a str, - region: &'a Region, -) -> Result, BuildError> { - let mut signing_settings = SigningSettings::default(); - signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; // required for OpenSearch Serverless - - let mut builder = SigningParams::builder() - .access_key(credentials.access_key_id()) - .secret_key(credentials.secret_access_key()) - .service_name(service_name) - .region(region.as_ref()) - .time(SystemTime::now()) - .settings(signing_settings); - - builder.set_security_token(credentials.session_token()); - - builder.build() -} - pub async fn sign_request( request: &mut Request, credentials_provider: &SharedCredentialsProvider, service_name: &str, region: &Region, + time_source: &SharedTimeSource, ) -> Result<(), Box> { - let credentials = credentials_provider.provide_credentials().await?; + let identity = { + let c = credentials_provider.provide_credentials().await?; + let e = c.expiry(); + Identity::new(c, e) + }; + + let signing_settings = { + let mut s = SigningSettings::default(); + s.payload_checksum_kind = PayloadChecksumKind::XAmzSha256; // required for OpenSearch Serverless + s + }; + + let params = { + let p = v4::SigningParams::builder() + .identity(&identity) + .name(service_name) + .region(region.as_ref()) + .time(time_source.now()) + .settings(signing_settings) + .build()?; + SigningParams::V4(p) + }; - let params = get_signing_params(&credentials, service_name, region)?; + let signable_request = { + let method = request.method().as_str(); + let uri = request.url().as_str(); + let headers = request.headers().iter().map(|(k, v)| { + ( + k.as_str(), + std::str::from_utf8(v.as_bytes()).expect("only utf-8 headers are signable"), + ) + }); + let body = match request.body() { + Some(b) => match b.as_bytes() { + Some(bytes) => SignableBody::Bytes(bytes), + None => SignableBody::UnsignedPayload, // Body is not in memory (ie. streaming), so we can't sign it + }, + None => SignableBody::Bytes(&[]), + }; - let uri = request.url().as_str().parse()?; + SignableRequest::new(method, uri, headers, body)? + }; - let signable_request = SignableRequest::new( - request.method(), - &uri, - request.headers(), - SignableBody::Bytes(request.body().and_then(|b| b.as_bytes()).unwrap_or(&[])), - ); + let (new_headers, new_query_params) = { + let (instructions, _) = sign(signable_request, ¶ms)?.into_parts(); + instructions.into_parts() + }; - let (mut instructions, _) = sign(signable_request, ¶ms)?.into_parts(); + for header in new_headers.into_iter() { + let mut value = HeaderValue::from_str(header.value()) + .expect("AWS signing header value must be a valid header"); + value.set_sensitive(header.sensitive()); + + request.headers_mut().insert(header.name(), value); + } - if let Some(new_headers) = instructions.take_headers() { - for (name, value) in new_headers.into_iter() { - request.headers_mut().insert( - name.expect("AWS signing header name must never be None"), - value, - ); - } + for (key, value) in new_query_params.into_iter() { + request.url_mut().query_pairs_mut().append_pair(key, &value); } Ok(()) diff --git a/opensearch/src/http/transport.rs b/opensearch/src/http/transport.rs index f70c6eb1..ca981bd2 100644 --- a/opensearch/src/http/transport.rs +++ b/opensearch/src/http/transport.rs @@ -47,6 +47,8 @@ use crate::{ Method, }, }; +#[cfg(feature = "aws-auth")] +use aws_types::sdk_config::SharedTimeSource; use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; use bytes::BytesMut; use dyn_clone::clone_trait_object; @@ -155,7 +157,9 @@ pub struct TransportBuilder { headers: HeaderMap, timeout: Option, #[cfg(feature = "aws-auth")] - service_name: String, + sigv4_service_name: String, + #[cfg(feature = "aws-auth")] + sigv4_time_source: Option, } impl TransportBuilder { @@ -177,7 +181,9 @@ impl TransportBuilder { headers: HeaderMap::new(), timeout: None, #[cfg(feature = "aws-auth")] - service_name: "es".to_string(), + sigv4_service_name: "es".to_string(), + #[cfg(feature = "aws-auth")] + sigv4_time_source: None, } } @@ -245,12 +251,21 @@ impl TransportBuilder { self } - /// Sets a global AWS service name. + /// Sets the AWS SigV4 signing service name. /// /// Default is "es". Other supported services are "aoss" for OpenSearch Serverless. #[cfg(feature = "aws-auth")] pub fn service_name(mut self, service_name: &str) -> Self { - self.service_name = service_name.to_string(); + self.sigv4_service_name = service_name.to_string(); + self + } + + /// Sets the AWS SigV4 signing time source. + /// + /// Default is `SystemTimeSource` + #[cfg(feature = "aws-auth")] + pub fn sigv4_time_source(mut self, sigv4_time_source: SharedTimeSource) -> Self { + self.sigv4_time_source = Some(sigv4_time_source); self } @@ -258,10 +273,6 @@ impl TransportBuilder { pub fn build(self) -> Result { let mut client_builder = self.client_builder; - if !self.headers.is_empty() { - client_builder = client_builder.default_headers(self.headers); - } - if let Some(t) = self.timeout { client_builder = client_builder.timeout(t); } @@ -326,8 +337,11 @@ impl TransportBuilder { client, conn_pool: self.conn_pool, credentials: self.credentials, + default_headers: self.headers, #[cfg(feature = "aws-auth")] - service_name: self.service_name, + sigv4_service_name: self.sigv4_service_name, + #[cfg(feature = "aws-auth")] + sigv4_time_source: self.sigv4_time_source.unwrap_or_default(), }) } } @@ -367,8 +381,11 @@ pub struct Transport { client: reqwest::Client, credentials: Option, conn_pool: Box, + default_headers: HeaderMap, + #[cfg(feature = "aws-auth")] + sigv4_service_name: String, #[cfg(feature = "aws-auth")] - service_name: String, + sigv4_time_source: SharedTimeSource, } impl Transport { @@ -446,10 +463,14 @@ impl Transport { } // default headers first, overwrite with any provided - let mut request_headers = HeaderMap::with_capacity(4 + headers.len()); + let mut request_headers = + HeaderMap::with_capacity(4 + self.default_headers.len() + headers.len()); request_headers.insert(CONTENT_TYPE, HeaderValue::from_static(DEFAULT_CONTENT_TYPE)); request_headers.insert(ACCEPT, HeaderValue::from_static(DEFAULT_ACCEPT)); request_headers.insert(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT)); + for (name, value) in self.default_headers.iter() { + request_headers.insert(name, value.clone()); + } for (name, value) in headers { request_headers.insert(name.unwrap(), value); } @@ -480,8 +501,9 @@ impl Transport { super::aws_auth::sign_request( &mut request, credentials_provider, - &self.service_name, + &self.sigv4_service_name, region, + &self.sigv4_time_source, ) .await .map_err(|e| crate::error::lib(format!("AWSV4 Signing Failed: {}", e)))?; diff --git a/opensearch/src/lib.rs b/opensearch/src/lib.rs index 28742e5a..226f6324 100644 --- a/opensearch/src/lib.rs +++ b/opensearch/src/lib.rs @@ -27,7 +27,7 @@ //! |-------------|---------------| //! | 1.x | 1.x | //! | 2.x | 2.x, 1.x^ | -//! - ^: With the exception of some previously deprecated APIs +//! - ^: With the exception of some previously deprecated APIs //! //! A major version of the client is compatible with the same major version of OpenSearch. //! Since OpenSearch is developed following [Semantic Versioning](https://semver.org/) principles, @@ -44,7 +44,7 @@ //! In the latter case, a 1.4.0 client won't contain API functions for APIs that are introduced in //! OpenSearch 1.5.0+, but for all other APIs available in OpenSearch, the respective API //! functions on the client will be compatible. -//! +//! //! In some instances, a new major version of OpenSearch may remain compatible with an //! older major version of the client, which may not warrant a need to update the client. //! Please consult COMPATIBILITY.md for more details. diff --git a/opensearch/tests/auth.rs b/opensearch/tests/auth.rs index 031be9c2..4967fed7 100644 --- a/opensearch/tests/auth.rs +++ b/opensearch/tests/auth.rs @@ -45,8 +45,9 @@ async fn basic_auth_header() -> anyhow::Result<()> { write!(encoder, "username:password").unwrap(); } - assert_eq!( - req.headers()["authorization"], + assert_header_eq!( + req, + "authorization", String::from_utf8(header_value).unwrap() ); http::Response::default() @@ -70,8 +71,9 @@ async fn api_key_header() -> anyhow::Result<()> { write!(encoder, "id:api_key").unwrap(); } - assert_eq!( - req.headers()["authorization"], + assert_header_eq!( + req, + "authorization", String::from_utf8(header_value).unwrap() ); http::Response::default() @@ -89,7 +91,7 @@ async fn api_key_header() -> anyhow::Result<()> { #[tokio::test] async fn bearer_header() -> anyhow::Result<()> { let server = server::http(move |req| async move { - assert_eq!(req.headers()["authorization"], "Bearer access_token"); + assert_header_eq!(req, "authorization", "Bearer access_token"); http::Response::default() }); diff --git a/opensearch/tests/aws_auth.rs b/opensearch/tests/aws_auth.rs index 72d6d13b..cfa6b4b9 100644 --- a/opensearch/tests/aws_auth.rs +++ b/opensearch/tests/aws_auth.rs @@ -12,15 +12,83 @@ #![cfg(feature = "aws-auth")] pub mod common; -use common::*; -use opensearch::OpenSearch; -use regex::Regex; - use aws_config::SdkConfig; use aws_credential_types::provider::SharedCredentialsProvider; -use aws_credential_types::Credentials; +use aws_credential_types::Credentials as AwsCredentials; +use aws_smithy_async::time::StaticTimeSource; use aws_types::region::Region; +use common::*; +use http::header::HOST; +use opensearch::{auth::Credentials, indices::IndicesCreateParts, OpenSearch}; +use regex::Regex; +use serde_json::json; use std::convert::TryInto; +use test_case::test_case; + +#[test_case("es", "10c9be415f4b9f15b12abbb16bd3e3730b2e6c76e0cf40db75d08a44ed04a3a1"; "when service name is es")] +#[test_case("aoss", "34903aef90423aa7dd60575d3d45316c6ef2d57bbe564a152b41bf8f5917abf6"; "when service name is aoss")] +#[test_case("arbitrary", "156e65c504ea2b2722a481b7515062e7692d27217b477828854e715f507e6a36"; "when service name is arbitrary")] +#[tokio::test] +async fn aws_auth_signs_correctly( + service_name: &str, + expected_signature: &str, +) -> anyhow::Result<()> { + tracing_init(); + + let (server, mut rx) = server::capturing_http(); + + let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test"); + let region = Region::new("ap-southeast-2"); + let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000 + let host = format!("aaabbbcccddd111222333.ap-southeast-2.{service_name}.amazonaws.com"); + + let transport_builder = client::create_builder(&format!("http://{}", server.addr())) + .auth(Credentials::AwsSigV4( + SharedCredentialsProvider::new(aws_creds), + region, + )) + .service_name(service_name) + .sigv4_time_source(time_source.into()) + .header(HOST, host.parse().unwrap()); + let client = client::create(transport_builder); + + let _ = client + .indices() + .create(IndicesCreateParts::Index("sample-index1")) + .body(json!({ + "aliases": { + "sample-alias1": {} + }, + "mappings": { + "properties": { + "age": { + "type": "integer" + } + } + }, + "settings": { + "index.number_of_replicas": 1, + "index.number_of_shards": 2 + } + })) + .send() + .await?; + + let sent_req = rx.recv().await.expect("should have sent a request"); + + assert_header_eq!(sent_req, "accept", "application/json"); + assert_header_eq!(sent_req, "content-type", "application/json"); + assert_header_eq!(sent_req, "host", host); + assert_header_eq!(sent_req, "x-amz-date", "20230113T160837Z"); + assert_header_eq!( + sent_req, + "x-amz-content-sha256", + "4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564" + ); + assert_header_eq!(sent_req, "authorization", format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}")); + + Ok(()) +} #[tokio::test] async fn aws_auth_get() -> anyhow::Result<()> { @@ -32,9 +100,9 @@ async fn aws_auth_get() -> anyhow::Result<()> { "{}", authorization_header ); - let amz_content_sha256_header = req.headers()["x-amz-content-sha256"].to_str().unwrap(); - assert_eq!( - amz_content_sha256_header, + assert_header_eq!( + req, + "x-amz-content-sha256", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" ); // SHA of empty string http::Response::default() @@ -49,9 +117,9 @@ async fn aws_auth_get() -> anyhow::Result<()> { #[tokio::test] async fn aws_auth_post() -> anyhow::Result<()> { let server = server::http(move |req| async move { - let amz_content_sha256_header = req.headers()["x-amz-content-sha256"].to_str().unwrap(); - assert_eq!( - amz_content_sha256_header, + assert_header_eq!( + req, + "x-amz-content-sha256", "f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9" ); // SHA of the JSON http::Response::default() @@ -73,7 +141,7 @@ async fn aws_auth_post() -> anyhow::Result<()> { } fn create_aws_client(addr: &str) -> anyhow::Result { - let aws_creds = Credentials::new("id", "secret", None, None, "token"); + let aws_creds = AwsCredentials::new("id", "secret", None, None, "token"); let creds_provider = SharedCredentialsProvider::new(aws_creds); let aws_config = SdkConfig::builder() .credentials_provider(creds_provider) diff --git a/opensearch/tests/client.rs b/opensearch/tests/client.rs index 9cf68dcb..48fa6e7f 100644 --- a/opensearch/tests/client.rs +++ b/opensearch/tests/client.rs @@ -52,9 +52,9 @@ use std::time::Duration; #[tokio::test] async fn default_user_agent_content_type_accept_headers() -> anyhow::Result<()> { let server = server::http(move |req| async move { - assert_eq!(req.headers()["user-agent"], DEFAULT_USER_AGENT); - assert_eq!(req.headers()["content-type"], "application/json"); - assert_eq!(req.headers()["accept"], "application/json"); + assert_header_eq!(req, "user-agent", DEFAULT_USER_AGENT); + assert_header_eq!(req, "content-type", "application/json"); + assert_header_eq!(req, "accept", "application/json"); http::Response::default() }); @@ -67,7 +67,7 @@ async fn default_user_agent_content_type_accept_headers() -> anyhow::Result<()> #[tokio::test] async fn default_header() -> anyhow::Result<()> { let server = server::http(move |req| async move { - assert_eq!(req.headers()["x-opaque-id"], "foo"); + assert_header_eq!(req, "x-opaque-id", "foo"); http::Response::default() }); @@ -85,7 +85,7 @@ async fn default_header() -> anyhow::Result<()> { #[tokio::test] async fn override_default_header() -> anyhow::Result<()> { let server = server::http(move |req| async move { - assert_eq!(req.headers()["x-opaque-id"], "bar"); + assert_header_eq!(req, "x-opaque-id", "bar"); http::Response::default() }); @@ -110,7 +110,7 @@ async fn override_default_header() -> anyhow::Result<()> { #[tokio::test] async fn x_opaque_id_header() -> anyhow::Result<()> { let server = server::http(move |req| async move { - assert_eq!(req.headers()["x-opaque-id"], "foo"); + assert_header_eq!(req, "x-opaque-id", "foo"); http::Response::default() }); diff --git a/opensearch/tests/common/mod.rs b/opensearch/tests/common/mod.rs index dd1dadcf..93dacd5c 100644 --- a/opensearch/tests/common/mod.rs +++ b/opensearch/tests/common/mod.rs @@ -28,8 +28,23 @@ * GitHub history for details. */ -pub mod client; -pub mod server; +#![allow(unused)] -#[allow(unused)] -pub static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); +pub(crate) mod client; +pub(crate) mod server; + +pub(crate) static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); + +macro_rules! assert_header_eq { + ($req:expr, $header:expr, $value:expr) => { + assert_eq!($req.headers()[$header], $value); + }; +} + +static TRACING: std::sync::Once = std::sync::Once::new(); + +pub(crate) fn tracing_init() { + TRACING.call_once(|| tracing_subscriber::fmt::init()) +} + +pub(crate) use assert_header_eq; diff --git a/opensearch/tests/common/server.rs b/opensearch/tests/common/server.rs index f3dcd8af..ec9e798a 100644 --- a/opensearch/tests/common/server.rs +++ b/opensearch/tests/common/server.rs @@ -36,7 +36,12 @@ use std::{ convert::Infallible, future::Future, net, sync::mpsc as std_mpsc, thread, time::Duration, }; -use tokio::sync::oneshot; +use http::Request; +use hyper::Body; +use tokio::sync::{ + mpsc::{unbounded_channel, UnboundedReceiver}, + oneshot, +}; pub use http::Response; use tokio::runtime; @@ -122,3 +127,15 @@ where .join() .unwrap() } + +pub fn capturing_http() -> (Server, UnboundedReceiver>) { + let (tx, rx) = unbounded_channel(); + let server = http(move |req| { + let tx = tx.clone(); + async move { + tx.send(req).unwrap(); + http::Response::default() + } + }); + (server, rx) +}