From 7004d293407e85c85657a4c650386c305b9e8550 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 13 Dec 2023 08:47:34 +1300 Subject: [PATCH] Refactor mock HTTP server tests (#227) (#231) * Refactor mock http server tests Signed-off-by: Thomas Farr * de-dup code Signed-off-by: Thomas Farr --------- Signed-off-by: Thomas Farr (cherry picked from commit 1035d4e60cf712488834bdfa526680f32b60efdb) Co-authored-by: Thomas Farr --- .github/workflows/test.yml | 2 +- Makefile.toml | 10 +- api_generator/src/rest_spec/mod.rs | 2 +- opensearch/Cargo.toml | 1 - opensearch/src/auth.rs | 4 +- opensearch/src/models/mod.rs | 13 ++ opensearch/tests/auth.rs | 85 +++----- opensearch/tests/aws_auth.rs | 137 ++++++------ opensearch/tests/cert.rs | 49 ++--- opensearch/tests/client.rs | 172 +++++++-------- opensearch/tests/common/client.rs | 88 +++++--- opensearch/tests/common/mod.rs | 20 +- opensearch/tests/common/server.rs | 266 +++++++++++++++--------- opensearch/tests/error.rs | 6 +- yaml_test_runner/src/generator.rs | 4 +- yaml_test_runner/src/github.rs | 2 +- yaml_test_runner/tests/common/macros.rs | 4 +- 17 files changed, 459 insertions(+), 406 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 82ecd08b..a9bbf290 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: version: 2.8.0 secured: true - - name: Run Tests (${{ matrix.test-args }}) + - name: Run Tests working-directory: client run: cargo make test ${{ matrix.test-args }} env: diff --git a/Makefile.toml b/Makefile.toml index 95b496eb..c01a1dc2 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -20,18 +20,14 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition category = "OpenSearch" description = "Generates SSL certificates used for integration tests" command = "bash" -args =["./.ci/generate-certs.sh"] +args = ["./.ci/generate-certs.sh"] [tasks.run-opensearch] category = "OpenSearch" private = true condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] } - -[tasks.run-opensearch.linux] -command = "./.ci/run-opensearch.sh" - -[tasks.run-opensearch.mac] -command = "./.ci/run-opensearch.sh" +command = "bash" +args = ["./.ci/run-opensearch.sh"] [tasks.run-opensearch.windows] script_runner = "cmd" diff --git a/api_generator/src/rest_spec/mod.rs b/api_generator/src/rest_spec/mod.rs index 6903a1b2..9cb037a2 100644 --- a/api_generator/src/rest_spec/mod.rs +++ b/api_generator/src/rest_spec/mod.rs @@ -41,7 +41,7 @@ pub fn download_specs(branch: &str, download_dir: &Path) -> anyhow::Result<()> { .build() .unwrap(); - let response = client.get(&url).send()?; + let response = client.get(url).send()?; let tar = GzDecoder::new(response); let mut archive = Archive::new(tar); diff --git a/opensearch/Cargo.toml b/opensearch/Cargo.toml index 5090436c..cc2a87fa 100644 --- a/opensearch/Cargo.toml +++ b/opensearch/Cargo.toml @@ -55,7 +55,6 @@ futures = "0.3.1" http-body-util = "0.1.0" hyper = { version = "1", features = ["full"] } hyper-util = { version = "0.1", features = ["full"] } -regex="1.4" sysinfo = "0.29.0" test-case = "3" textwrap = "0.16" diff --git a/opensearch/src/auth.rs b/opensearch/src/auth.rs index 787917ba..0affb539 100644 --- a/opensearch/src/auth.rs +++ b/opensearch/src/auth.rs @@ -90,7 +90,7 @@ impl From for Credentials { } } -#[cfg(any(feature = "aws-auth"))] +#[cfg(feature = "aws-auth")] impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials { type Error = super::Error; @@ -107,7 +107,7 @@ impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials { } } -#[cfg(any(feature = "aws-auth"))] +#[cfg(feature = "aws-auth")] impl std::convert::TryFrom for Credentials { type Error = super::Error; diff --git a/opensearch/src/models/mod.rs b/opensearch/src/models/mod.rs index 0fb2d8a8..88c1987a 100644 --- a/opensearch/src/models/mod.rs +++ b/opensearch/src/models/mod.rs @@ -1,3 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#![allow(unused)] + use serde::Deserialize; #[derive(Deserialize, Debug)] diff --git a/opensearch/tests/auth.rs b/opensearch/tests/auth.rs index 43ea1896..943f5a0c 100644 --- a/opensearch/tests/auth.rs +++ b/opensearch/tests/auth.rs @@ -29,77 +29,58 @@ */ pub mod common; -use common::*; +use crate::common::server::MockServer; use opensearch::auth::Credentials; -use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder}; -use std::io::Write; - #[tokio::test] async fn basic_auth_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let mut header_value = b"Basic ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "username:password").unwrap(); - } - - assert_header_eq!( - req, - "authorization", - String::from_utf8(header_value).unwrap() - ); - server::empty_response() - }); - - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::Basic("username".into(), "password".into())); - - let client = client::create(builder); - let _response = client.ping().send().await?; + let mut server = MockServer::start()?; + + let client = + server.client_with(|b| b.auth(Credentials::Basic("username".into(), "password".into()))); + + let _ = client.ping().send().await?; + + let request = server.received_request().await?; + + assert_eq!( + request.header("authorization"), + Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ=") + ); Ok(()) } #[tokio::test] async fn api_key_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let mut header_value = b"ApiKey ".to_vec(); - { - let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD); - write!(encoder, "id:api_key").unwrap(); - } - - assert_header_eq!( - req, - "authorization", - String::from_utf8(header_value).unwrap() - ); - server::empty_response() - }); - - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::ApiKey("id".into(), "api_key".into())); - - let client = client::create(builder); - let _response = client.ping().send().await?; + let mut server = MockServer::start()?; + + let client = server.client_with(|b| b.auth(Credentials::ApiKey("id".into(), "api_key".into()))); + + let _ = client.ping().send().await?; + + let request = server.received_request().await?; + + assert_eq!( + request.header("authorization"), + Some("ApiKey aWQ6YXBpX2tleQ==") + ); Ok(()) } #[tokio::test] async fn bearer_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "authorization", "Bearer access_token"); - server::empty_response() - }); + let mut server = MockServer::start()?; + + let client = server.client_with(|b| b.auth(Credentials::Bearer("access_token".into()))); + + let _ = client.ping().send().await?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .auth(Credentials::Bearer("access_token".into())); + let request = server.received_request().await?; - let client = client::create(builder); - let _response = client.ping().send().await?; + assert_eq!(request.header("authorization"), Some("Bearer access_token")); Ok(()) } diff --git a/opensearch/tests/aws_auth.rs b/opensearch/tests/aws_auth.rs index 200d421f..9c7e1440 100644 --- a/opensearch/tests/aws_auth.rs +++ b/opensearch/tests/aws_auth.rs @@ -12,19 +12,34 @@ #![cfg(feature = "aws-auth")] pub mod common; -use aws_config::SdkConfig; -use aws_credential_types::provider::SharedCredentialsProvider; -use aws_credential_types::Credentials as AwsCredentials; +use aws_credential_types::{provider::SharedCredentialsProvider, Credentials as AwsCredentials}; use aws_smithy_async::time::StaticTimeSource; use aws_types::region::Region; -use common::*; -use opensearch::{auth::Credentials, indices::IndicesCreateParts, OpenSearch}; -use regex::Regex; -use reqwest::header::HOST; +use common::{server::MockServer, tracing_init}; +use opensearch::{ + http::{headers::HOST, transport::TransportBuilder}, + indices::IndicesCreateParts, +}; +use reqwest::header::HeaderValue; use serde_json::json; -use std::convert::TryInto; use test_case::test_case; +fn sigv4_config(transport: TransportBuilder, service_name: &str) -> TransportBuilder { + let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test"); + let region = Region::new("ap-southeast-2"); + let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000 + + transport + .auth(opensearch::auth::Credentials::AwsSigV4( + SharedCredentialsProvider::new(aws_creds), + region, + )) + .service_name(service_name) + .sigv4_time_source(time_source.into()) +} + +const LOCALHOST: HeaderValue = HeaderValue::from_static("localhost"); + #[test_case("es", "10c9be415f4b9f15b12abbb16bd3e3730b2e6c76e0cf40db75d08a44ed04a3a1"; "when service name is es")] #[test_case("aoss", "34903aef90423aa7dd60575d3d45316c6ef2d57bbe564a152b41bf8f5917abf6"; "when service name is aoss")] #[test_case("arbitrary", "156e65c504ea2b2722a481b7515062e7692d27217b477828854e715f507e6a36"; "when service name is arbitrary")] @@ -35,22 +50,12 @@ async fn aws_auth_signs_correctly( ) -> anyhow::Result<()> { tracing_init(); - let (server, mut rx) = server::capturing_http(); + let mut server = MockServer::start()?; - let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test"); - let region = Region::new("ap-southeast-2"); - let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000 let host = format!("aaabbbcccddd111222333.ap-southeast-2.{service_name}.amazonaws.com"); - let transport_builder = client::create_builder(&format!("http://{}", server.addr())) - .auth(Credentials::AwsSigV4( - SharedCredentialsProvider::new(aws_creds), - region, - )) - .service_name(service_name) - .sigv4_time_source(time_source.into()) - .header(HOST, host.parse().unwrap()); - let client = client::create(transport_builder); + let client = + server.client_with(|b| sigv4_config(b, service_name).header(HOST, host.parse().unwrap())); let _ = client .indices() @@ -74,59 +79,51 @@ async fn aws_auth_signs_correctly( .send() .await?; - let sent_req = rx.recv().await.expect("should have sent a request"); + let sent_req = server.received_request().await?; - assert_header_eq!(sent_req, "accept", "application/json"); - assert_header_eq!(sent_req, "content-type", "application/json"); - assert_header_eq!(sent_req, "host", host); - assert_header_eq!(sent_req, "x-amz-date", "20230113T160837Z"); - assert_header_eq!( - sent_req, - "x-amz-content-sha256", - "4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564" + assert_eq!(sent_req.header("accept"), Some("application/json")); + assert_eq!(sent_req.header("content-type"), Some("application/json")); + assert_eq!(sent_req.header("host"), Some(host.as_str())); + assert_eq!(sent_req.header("x-amz-date"), Some("20230113T160837Z")); + assert_eq!( + sent_req.header("x-amz-content-sha256"), + Some("4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564") ); - assert_header_eq!(sent_req, "authorization", format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}")); + assert_eq!(sent_req.header("authorization"), Some(format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}").as_str())); Ok(()) } #[tokio::test] async fn aws_auth_get() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - let authorization_header = req.headers()["authorization"].to_str().unwrap(); - let re = Regex::new(r"^AWS4-HMAC-SHA256 Credential=id/\d*/us-west-1/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=[a-f,0-9].*$").unwrap(); - assert!( - re.is_match(authorization_header), - "{}", - authorization_header - ); - assert_header_eq!( - req, - "x-amz-content-sha256", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - ); // SHA of empty string - server::empty_response() - }); - - let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?; - let _response = client.ping().send().await?; + tracing_init(); + + let mut server = MockServer::start()?; + + let client = server.client_with(|b| sigv4_config(b, "custom").header(HOST, LOCALHOST)); + + let _ = client.ping().send().await?; + + let sent_req = server.received_request().await?; + + assert_eq!(sent_req.header("authorization"), Some("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=e5aa6e5d9e1b86b86ed31fbb10dd62b4e93423b77830f8189701421d3e9f65bd")); + assert_eq!( + sent_req.header("x-amz-content-sha256"), + Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + ); // SHA of zero-length body Ok(()) } #[tokio::test] async fn aws_auth_post() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!( - req, - "x-amz-content-sha256", - "f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9" - ); // SHA of the JSON - server::empty_response() - }); - - let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?; - client + tracing_init(); + + let mut server = MockServer::start()?; + + let client = server.client_with(|b| sigv4_config(b, "custom").header(HOST, LOCALHOST)); + + let _ = client .index(opensearch::IndexParts::Index("movies")) .body(serde_json::json!({ "title": "Moneyball", @@ -137,18 +134,12 @@ async fn aws_auth_post() -> anyhow::Result<()> { .send() .await?; - Ok(()) -} + let sent_req = server.received_request().await?; -fn create_aws_client(addr: &str) -> anyhow::Result { - let aws_creds = AwsCredentials::new("id", "secret", None, None, "token"); - let creds_provider = SharedCredentialsProvider::new(aws_creds); - let aws_config = SdkConfig::builder() - .credentials_provider(creds_provider) - .region(Region::new("us-west-1")) - .build(); - let builder = client::create_builder(addr) - .auth(aws_config.clone().try_into()?) - .service_name("custom"); - Ok(client::create(builder)) + assert_eq!( + sent_req.header("x-amz-content-sha256"), + Some("f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9") + ); // SHA of the JSON + + Ok(()) } diff --git a/opensearch/tests/cert.rs b/opensearch/tests/cert.rs index df5f6493..1b85d29b 100644 --- a/opensearch/tests/cert.rs +++ b/opensearch/tests/cert.rs @@ -50,8 +50,7 @@ fn expected_error_message() -> &'static str { #[tokio::test] #[cfg(feature = "native-tls")] async fn default_certificate_validation() -> anyhow::Result<()> { - let builder = client::create_default_builder().cert_validation(CertificateValidation::Default); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Default)); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -77,8 +76,7 @@ async fn default_certificate_validation() -> anyhow::Result<()> { #[tokio::test] #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))] async fn default_certificate_validation_rustls_tls() -> anyhow::Result<()> { - let builder = client::create_default_builder().cert_validation(CertificateValidation::Default); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Default)); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -103,8 +101,7 @@ async fn default_certificate_validation_rustls_tls() -> anyhow::Result<()> { /// Allows any certificate through #[tokio::test] async fn none_certificate_validation() -> anyhow::Result<()> { - let builder = client::create_default_builder().cert_validation(CertificateValidation::None); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::None)); let _response = client.ping().send().await?; Ok(()) } @@ -118,9 +115,7 @@ async fn none_certificate_validation() -> anyhow::Result<()> { ))] async fn full_certificate_ca_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(CA_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -135,9 +130,7 @@ async fn full_certificate_ca_chain_validation() -> anyhow::Result<()> { let mut cert = Certificate::from_pem(CA_CHAIN_CERT)?; cert.append(Certificate::from_pem(CA_CERT)?); assert_eq!(cert.len(), 3, "expected three certificates in CA chain"); - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -147,9 +140,7 @@ async fn full_certificate_ca_chain_validation() -> anyhow::Result<()> { #[cfg(all(windows, feature = "native-tls"))] async fn full_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -163,9 +154,7 @@ async fn full_certificate_validation_rustls_tls() -> anyhow::Result<()> { chain.extend(ESNODE_CERT); let cert = Certificate::from_pem(chain.as_slice())?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -176,9 +165,7 @@ async fn full_certificate_validation_rustls_tls() -> anyhow::Result<()> { #[cfg(all(unix, feature = "native-tls"))] async fn full_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Full(cert)); - let client = client::create(builder); + let client = client::create_with(|b| b.cert_validation(CertificateValidation::Full(cert))); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -205,9 +192,8 @@ async fn full_certificate_validation() -> anyhow::Result<()> { #[cfg(all(windows, feature = "native-tls"))] async fn certificate_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -218,9 +204,8 @@ async fn certificate_certificate_validation() -> anyhow::Result<()> { #[cfg(all(unix, feature = "native-tls"))] async fn certificate_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); match client.ping().send().await { Ok(response) => Err(anyhow!( @@ -248,9 +233,8 @@ async fn certificate_certificate_validation() -> anyhow::Result<()> { #[cfg(all(feature = "native-tls", not(target_os = "macos")))] async fn certificate_certificate_ca_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(CA_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); let _response = client.ping().send().await?; Ok(()) } @@ -260,9 +244,8 @@ async fn certificate_certificate_ca_validation() -> anyhow::Result<()> { #[cfg(feature = "native-tls")] async fn fail_certificate_certificate_validation() -> anyhow::Result<()> { let cert = Certificate::from_pem(ESNODE_NO_SAN_CERT)?; - let builder = - client::create_default_builder().cert_validation(CertificateValidation::Certificate(cert)); - let client = client::create(builder); + let client = + client::create_with(|b| b.cert_validation(CertificateValidation::Certificate(cert))); match client.ping().send().await { Ok(response) => Err(anyhow!( diff --git a/opensearch/tests/client.rs b/opensearch/tests/client.rs index f4358f67..419d25ec 100644 --- a/opensearch/tests/client.rs +++ b/opensearch/tests/client.rs @@ -30,7 +30,10 @@ pub mod common; use common::*; +use hyper::Method; +use crate::common::{client::index_documents, server::MockServer}; +use bytes::Bytes; use opensearch::{ http::{ headers::{ @@ -42,60 +45,56 @@ use opensearch::{ params::TrackTotalHits, SearchParts, }; - -use crate::common::client::index_documents; -use bytes::Bytes; -use hyper::Method; use serde_json::{json, Value}; use std::time::Duration; #[tokio::test] async fn default_user_agent_content_type_accept_headers() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "user-agent", DEFAULT_USER_AGENT); - assert_header_eq!(req, "content-type", "application/json"); - assert_header_eq!(req, "accept", "application/json"); - server::empty_response() - }); + let mut server = MockServer::start()?; + + let _ = server.client().ping().send().await?; + + let request = server.received_request().await?; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client.ping().send().await?; + assert_eq!(request.header("user-agent"), Some(DEFAULT_USER_AGENT)); + assert_eq!(request.header("content-type"), Some(DEFAULT_CONTENT_TYPE)); + assert_eq!(request.header("accept"), Some(DEFAULT_ACCEPT)); Ok(()) } #[tokio::test] async fn default_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "x-opaque-id", "foo"); - server::empty_response() + let mut server = MockServer::start()?; + + let client = server.client_with(|b| { + b.header( + HeaderName::from_static(X_OPAQUE_ID), + HeaderValue::from_static("foo"), + ) }); - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()).header( - HeaderName::from_static(X_OPAQUE_ID), - HeaderValue::from_static("foo"), - ); + let _ = client.ping().send().await?; - let client = client::create(builder); - let _response = client.ping().send().await?; + let request = server.received_request().await?; + + assert_eq!(request.header("x-opaque-id"), Some("foo")); Ok(()) } #[tokio::test] async fn override_default_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "x-opaque-id", "bar"); - server::empty_response() - }); + let mut server = MockServer::start()?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()).header( - HeaderName::from_static(X_OPAQUE_ID), - HeaderValue::from_static("foo"), - ); + let client = server.client_with(|b| { + b.header( + HeaderName::from_static(X_OPAQUE_ID), + HeaderValue::from_static("foo"), + ) + }); - let client = client::create(builder); - let _response = client + let _ = client .ping() .header( HeaderName::from_static(X_OPAQUE_ID), @@ -104,18 +103,19 @@ async fn override_default_header() -> anyhow::Result<()> { .send() .await?; + let request = server.received_request().await?; + + assert_eq!(request.header("x-opaque-id"), Some("bar")); + Ok(()) } #[tokio::test] async fn x_opaque_id_header() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_header_eq!(req, "x-opaque-id", "foo"); - server::empty_response() - }); + let mut server = MockServer::start()?; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client + let _ = server + .client() .ping() .header( HeaderName::from_static(X_OPAQUE_ID), @@ -124,39 +124,39 @@ async fn x_opaque_id_header() -> anyhow::Result<()> { .send() .await?; + let request = server.received_request().await?; + + assert_eq!(request.header("x-opaque-id"), Some("foo")); + Ok(()) } #[tokio::test] -async fn uses_global_request_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - server::empty_response() - }); +async fn uses_global_request_timeout() -> anyhow::Result<()> { + let server = MockServer::builder() + .response_delay(Duration::from_secs(1)) + .start()?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_millis(500)); + let client = server.client_with(|b| b.timeout(Duration::from_millis(500))); - let client = client::create(builder); let response = client.ping().send().await; match response { Ok(_) => panic!("Expected timeout error, but response received"), Err(e) => assert!(e.is_timeout(), "Expected timeout error, but was {:?}", e), } + + Ok(()) } #[tokio::test] -async fn uses_call_request_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - server::empty_response() - }); +async fn uses_call_request_timeout() -> anyhow::Result<()> { + let server = MockServer::builder() + .response_delay(Duration::from_secs(1)) + .start()?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_secs(2)); + let client = server.client_with(|b| b.timeout(Duration::from_secs(2))); - let client = client::create(builder); let response = client .ping() .request_timeout(Duration::from_millis(500)) @@ -167,34 +167,36 @@ async fn uses_call_request_timeout() { Ok(_) => panic!("Expected timeout error, but response received"), Err(e) => assert!(e.is_timeout(), "Expected timeout error, but was {:?}", e), } + + Ok(()) } #[tokio::test] -async fn call_request_timeout_supersedes_global_timeout() { - let server = server::http(move |_| async move { - std::thread::sleep(Duration::from_secs(1)); - server::empty_response() - }); +async fn call_request_timeout_supersedes_global_timeout() -> anyhow::Result<()> { + let server = MockServer::builder() + .response_delay(Duration::from_secs(1)) + .start()?; - let builder = client::create_builder(format!("http://{}", server.addr()).as_ref()) - .timeout(std::time::Duration::from_millis(500)); + let client = server.client_with(|b| b.timeout(Duration::from_millis(500))); - let client = client::create(builder); let response = client .ping() .request_timeout(Duration::from_secs(2)) .send() .await; - match response { - Ok(_) => (), - Err(e) => assert!(e.is_timeout(), "Did not expect error, but was {:?}", e), - } + assert!( + response.is_ok(), + "Expected response, but was: {:?}", + response + ); + + Ok(()) } #[tokio::test] async fn deprecation_warning_headers() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -239,18 +241,10 @@ async fn deprecation_warning_headers() -> anyhow::Result<()> { #[tokio::test] async fn serialize_querystring() -> anyhow::Result<()> { - let server = server::http(move |req| async move { - assert_eq!(req.method(), Method::GET); - assert_eq!(req.uri().path(), "/_search"); - assert_eq!( - req.uri().query(), - Some("filter_path=took%2C_shards&pretty=true&q=title%3AOpenSearch&track_total_hits=100000") - ); - server::empty_response() - }); + let mut server = MockServer::start()?; - let client = client::create_for_url(format!("http://{}", server.addr()).as_ref()); - let _response = client + let _ = server + .client() .search(SearchParts::None) .pretty(true) .filter_path(&["took", "_shards"]) @@ -259,12 +253,20 @@ async fn serialize_querystring() -> anyhow::Result<()> { .send() .await?; + let request = server.received_request().await?; + assert_eq!(request.method(), Method::GET); + assert_eq!(request.path(), "/_search"); + assert_eq!( + request.query(), + Some("filter_path=took%2C_shards&pretty=true&q=title%3AOpenSearch&track_total_hits=100000") + ); + Ok(()) } #[tokio::test] async fn search_with_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -307,7 +309,7 @@ async fn search_with_body() -> anyhow::Result<()> { #[tokio::test] async fn search_with_no_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -330,7 +332,7 @@ async fn search_with_no_body() -> anyhow::Result<()> { #[tokio::test] async fn read_response_as_bytes() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let response = client .search(SearchParts::None) @@ -356,7 +358,7 @@ async fn read_response_as_bytes() -> anyhow::Result<()> { #[tokio::test] async fn cat_health_format_json() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let response = client .cat() .health() @@ -380,7 +382,7 @@ async fn cat_health_format_json() -> anyhow::Result<()> { #[tokio::test] async fn cat_health_header_json() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let response = client .cat() .health() @@ -404,7 +406,7 @@ async fn cat_health_header_json() -> anyhow::Result<()> { #[tokio::test] async fn cat_health_text() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let response = client.cat().health().pretty(true).send().await?; assert_eq!(response.status_code(), StatusCode::OK); @@ -422,7 +424,7 @@ async fn cat_health_text() -> anyhow::Result<()> { #[tokio::test] async fn clone_search_with_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let _ = index_documents(&client).await?; let base_request = client.search(SearchParts::None); @@ -448,7 +450,7 @@ async fn clone_search_with_body() -> anyhow::Result<()> { #[tokio::test] async fn byte_slice_body() -> anyhow::Result<()> { - let client = client::create_default(); + let client = client::create(); let body = b"{\"query\":{\"match_all\":{}}}"; let response = client diff --git a/opensearch/tests/common/client.rs b/opensearch/tests/common/client.rs index ecbde471..306132cc 100644 --- a/opensearch/tests/common/client.rs +++ b/opensearch/tests/common/client.rs @@ -35,12 +35,12 @@ use opensearch::{ http::{ response::Response, transport::{SingleNodeConnectionPool, TransportBuilder}, + StatusCode, }, indices::IndicesExistsParts, params::Refresh, BulkOperation, BulkParts, Error, OpenSearch, DEFAULT_ADDRESS, }; -use reqwest::StatusCode; use serde_json::json; use sysinfo::{ProcessRefreshKind, RefreshKind, System, SystemExt}; use url::Url; @@ -63,45 +63,75 @@ fn running_proxy() -> bool { has_fiddler } -pub fn create_default_builder() -> TransportBuilder { - create_builder(cluster_addr().as_str()) -} +pub struct TestClientBuilder(TransportBuilder); + +impl TestClientBuilder { + pub fn new() -> Self { + Self::with_url(&cluster_addr()) + } + + pub fn with_url(url: &str) -> Self { + let url = Url::parse(url).unwrap(); + let secure = url.scheme() == "https"; + let conn_pool = SingleNodeConnectionPool::new(url); + let mut builder = TransportBuilder::new(conn_pool); + + // assume if we're running with HTTPS then authentication is also enabled and disable + // certificate validation - we'll change this for tests that need to. + if secure { + builder = builder.auth(Credentials::Basic("admin".into(), "admin".into())); -pub fn create_builder(addr: &str) -> TransportBuilder { - let url = Url::parse(addr).unwrap(); - let conn_pool = SingleNodeConnectionPool::new(url.clone()); - let mut builder = TransportBuilder::new(conn_pool); - // assume if we're running with HTTPS then authentication is also enabled and disable - // certificate validation - we'll change this for tests that need to. - if url.scheme() == "https" { - builder = builder.auth(Credentials::Basic("admin".into(), "admin".into())); - - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - { - builder = builder.cert_validation(CertificateValidation::None); + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + { + builder = builder.cert_validation(CertificateValidation::None); + } } + + Self(builder) + } + + pub fn with(mut self, configurator: impl FnOnce(TransportBuilder) -> TransportBuilder) -> Self { + self.0 = configurator(self.0); + self } - builder + pub fn build(self) -> OpenSearch { + let mut builder = self.0; + + if running_proxy() { + let proxy_url = Url::parse("http://localhost:8888").unwrap(); + builder = builder.proxy(proxy_url, None, None); + } + + let transport = builder.build().unwrap(); + OpenSearch::new(transport) + } } -pub fn create_default() -> OpenSearch { - create_for_url(cluster_addr().as_str()) +impl Default for TestClientBuilder { + fn default() -> Self { + Self::new() + } } -pub fn create_for_url(url: &str) -> OpenSearch { - let builder = create_builder(url); - create(builder) +pub fn builder() -> TestClientBuilder { + TestClientBuilder::new() } -pub fn create(mut builder: TransportBuilder) -> OpenSearch { - if running_proxy() { - let proxy_url = Url::parse("http://localhost:8888").unwrap(); - builder = builder.proxy(proxy_url, None, None); - } +pub fn builder_with_url(url: &str) -> TestClientBuilder { + TestClientBuilder::with_url(url) +} + +pub fn create() -> OpenSearch { + builder().build() +} + +pub fn create_with(configurator: impl FnOnce(TransportBuilder) -> TransportBuilder) -> OpenSearch { + builder().with(configurator).build() +} - let transport = builder.build().unwrap(); - OpenSearch::new(transport) +pub fn create_with_url(url: &str) -> OpenSearch { + builder_with_url(url).build() } /// index some documents into a posts index. If the posts index already exists, do nothing. diff --git a/opensearch/tests/common/mod.rs b/opensearch/tests/common/mod.rs index 93dacd5c..5c879006 100644 --- a/opensearch/tests/common/mod.rs +++ b/opensearch/tests/common/mod.rs @@ -28,23 +28,13 @@ * GitHub history for details. */ -#![allow(unused)] +pub mod client; +pub mod server; -pub(crate) mod client; -pub(crate) mod server; - -pub(crate) static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); - -macro_rules! assert_header_eq { - ($req:expr, $header:expr, $value:expr) => { - assert_eq!($req.headers()[$header], $value); - }; -} +pub static DEFAULT_USER_AGENT: &str = concat!("opensearch-rs/", env!("CARGO_PKG_VERSION")); static TRACING: std::sync::Once = std::sync::Once::new(); -pub(crate) fn tracing_init() { - TRACING.call_once(|| tracing_subscriber::fmt::init()) +pub fn tracing_init() { + TRACING.call_once(tracing_subscriber::fmt::init) } - -pub(crate) use assert_header_eq; diff --git a/opensearch/tests/common/server.rs b/opensearch/tests/common/server.rs index 07069105..24b5cfe7 100644 --- a/opensearch/tests/common/server.rs +++ b/opensearch/tests/common/server.rs @@ -32,141 +32,211 @@ // Licensed under Apache License, Version 2.0 // https://github.com/seanmonstar/reqwest/blob/master/LICENSE-APACHE -use std::{ - convert::Infallible, - future::Future, - net::{self, SocketAddr}, - sync::mpsc as std_mpsc, - thread, - time::Duration, -}; +use std::{convert::identity, net::SocketAddr, sync::mpsc as std_mpsc, thread, time::Duration}; use bytes::Bytes; use http_body_util::Empty; use hyper::{ - body::{Body, Incoming}, - server::conn::http1, - service::service_fn, - Request, Response, + body::Incoming, server::conn::http1, service::service_fn, HeaderMap, Method, Request, Response, + Uri, }; use hyper_util::rt::TokioIo; +use opensearch::{http::transport::TransportBuilder, OpenSearch}; use tokio::{ net::{TcpListener, TcpStream}, - sync::{broadcast, mpsc}, + pin, runtime, select, + sync::{mpsc, watch}, + task, + time::sleep, }; -use tokio::runtime; +use super::client::TestClientBuilder; -pub struct Server { - addr: net::SocketAddr, - panic_rx: std_mpsc::Receiver<()>, - shutdown_tx: Option>, +#[derive(Clone)] +struct RequestState { + requests_tx: mpsc::UnboundedSender, + response_delay: Option, + shutdown_rx: watch::Receiver, } -impl Server { - pub fn addr(&self) -> net::SocketAddr { - self.addr - } +#[derive(Default)] +pub struct MockServerBuilder { + response_delay: Option, } -impl Drop for Server { - fn drop(&mut self) { - if let Some(tx) = self.shutdown_tx.take() { - tx.send(()).unwrap(); +impl MockServerBuilder { + pub fn response_delay(mut self, delay: Duration) -> Self { + self.response_delay = Some(delay); + self + } + + async fn handle_request( + req: Request, + state: RequestState, + ) -> anyhow::Result>> { + state.requests_tx.send(req.into())?; + if let Some(response_delay) = state.response_delay { + sleep(response_delay).await; } + Ok(Default::default()) + } - if !::std::thread::panicking() { - self.panic_rx - .recv_timeout(Duration::from_secs(3)) - .expect("test server should not panic"); + async fn serve_connection(io: TokioIo, state: RequestState) { + let mut shutdown_rx = state.shutdown_rx.clone(); + let conn = http1::Builder::new().serve_connection( + io, + service_fn(move |req| Self::handle_request(req, state.clone())), + ); + pin!(conn); + select! { + _ = conn.as_mut() => {}, + _ = shutdown_rx.changed() => conn.as_mut().graceful_shutdown() } } -} -pub fn http(func: F) -> Server -where - F: Fn(Request) -> Fut + Clone + Send + 'static, - Fut: Future> + Send + 'static, - B: Body + Send + 'static, - B::Data: Send, - B::Error: std::error::Error + Send + Sync, -{ - let thread_name = thread::current().name().unwrap_or("").to_owned(); - - thread::spawn(move || { + async fn serve(listener: TcpListener, state: RequestState) -> anyhow::Result<()> { + let mut shutdown_rx = state.shutdown_rx.clone(); + loop { + let (stream, _) = tokio::select! { + res = listener.accept() => res?, + _ = shutdown_rx.changed() => break + }; + let io = TokioIo::new(stream); + + task::spawn(Self::serve_connection(io, state.clone())); + } + Ok(()) + } + + fn start_inner(self, thread_name: String) -> anyhow::Result { let rt = runtime::Builder::new_current_thread() .enable_all() - .build() - .expect("new rt"); + .build()?; let _ = rt.enter(); - let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1); - let listener = rt - .block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))) - .unwrap(); - let addr = listener.local_addr().unwrap(); - - let srv = async move { - loop { - let (stream, _) = tokio::select! { - res = listener.accept() => res?, - _ = shutdown_rx.recv() => break - }; - let io = TokioIo::new(stream); - - let mut func = func.clone(); - let mut shutdown_rx = shutdown_rx.resubscribe(); - - tokio::task::spawn(async move { - let conn = http1::Builder::new().serve_connection( - io, - service_fn(move |req| { - let func = func.clone(); - async move { Ok::<_, Infallible>(func(req).await) } - }), - ); - tokio::pin!(conn); - tokio::select! { - res = conn.as_mut() => {}, - _ = shutdown_rx.recv() => conn.as_mut().graceful_shutdown() - } - }); - } - Ok::<(), anyhow::Error>(()) - }; + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let (requests_tx, requests_rx) = mpsc::unbounded_channel(); + let listener = rt.block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))))?; + let addr = listener.local_addr()?; + + let srv = Self::serve( + listener, + RequestState { + requests_tx, + response_delay: self.response_delay, + shutdown_rx, + }, + ); let (panic_tx, panic_rx) = std_mpsc::channel(); - let thread_name = format!("test({})-support-server", thread_name); thread::Builder::new() - .name(thread_name) + .name(format!("test({})-support-server", thread_name)) .spawn(move || { rt.block_on(srv).unwrap(); let _ = panic_tx.send(()); - }) - .expect("thread spawn"); + })?; - Server { - addr, + Ok(MockServer { + uri: format!("http://{}", addr), + requests_rx, panic_rx, shutdown_tx: Some(shutdown_tx), + }) + } + + pub fn start(self) -> anyhow::Result { + let thread_name = thread::current().name().unwrap_or("").to_owned(); + + match thread::spawn(move || self.start_inner(thread_name)).join() { + Ok(r) => r, + Err(e) => Err(anyhow::anyhow!("MockServer startup panicked: {:?}", e)), } - }) - .join() - .unwrap() + } +} + +pub struct MockServer { + uri: String, + requests_rx: mpsc::UnboundedReceiver, + panic_rx: std_mpsc::Receiver<()>, + shutdown_tx: Option>, +} + +impl MockServer { + pub fn builder() -> MockServerBuilder { + MockServerBuilder::default() + } + + pub fn start() -> anyhow::Result { + Self::builder().start() + } + + pub fn client(&self) -> OpenSearch { + self.client_with(identity) + } + + pub fn client_with( + &self, + configurator: impl FnOnce(TransportBuilder) -> TransportBuilder, + ) -> OpenSearch { + self.client_builder().with(configurator).build() + } + + pub fn client_builder(&self) -> TestClientBuilder { + super::client::builder_with_url(&self.uri) + } + + pub async fn received_request(&mut self) -> anyhow::Result { + self.requests_rx + .recv() + .await + .ok_or_else(|| anyhow::anyhow!("no request received")) + } } -pub fn capturing_http() -> (Server, mpsc::UnboundedReceiver>) { - let (tx, rx) = mpsc::unbounded_channel(); - let server = http(move |req| { - let tx = tx.clone(); - async move { - tx.send(req).unwrap(); - empty_response() +impl Drop for MockServer { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + tx.send(true).unwrap(); + } + + if !::std::thread::panicking() { + self.panic_rx + .recv_timeout(Duration::from_secs(3)) + .expect("test server should not panic"); } - }); - (server, rx) + } +} + +pub struct ReceivedRequest { + method: Method, + uri: Uri, + headers: HeaderMap, } -pub fn empty_response() -> Response> { - Default::default() +impl ReceivedRequest { + pub fn method(&self) -> &Method { + &self.method + } + + pub fn path(&self) -> &str { + self.uri.path() + } + + pub fn query(&self) -> Option<&str> { + self.uri.query() + } + + pub fn header(&self, name: &str) -> Option<&str> { + self.headers.get(name).and_then(|v| v.to_str().ok()) + } +} + +impl From> for ReceivedRequest { + fn from(req: Request) -> Self { + ReceivedRequest { + method: req.method().clone(), + uri: req.uri().clone(), + headers: req.headers().clone(), + } + } } diff --git a/opensearch/tests/error.rs b/opensearch/tests/error.rs index 382eab69..eb97e65d 100644 --- a/opensearch/tests/error.rs +++ b/opensearch/tests/error.rs @@ -38,8 +38,7 @@ use serde_json::{json, Value}; /// Responses in the range 400-599 return Response body #[tokio::test] async fn bad_request_returns_response() -> anyhow::Result<()> { - let client = client::create_default(); - let response = client + let response = client::create() .explain(ExplainParts::IndexId("non_existent_index", "id")) .body(json!({})) .send() @@ -63,8 +62,7 @@ async fn bad_request_returns_response() -> anyhow::Result<()> { #[tokio::test] async fn deserialize_exception() -> anyhow::Result<()> { - let client = client::create_default(); - let response = client + let response = client::create() .explain(ExplainParts::IndexId("non_existent_index", "id")) .error_trace(true) .body(json!({})) diff --git a/yaml_test_runner/src/generator.rs b/yaml_test_runner/src/generator.rs index d4571bf8..ffc6a57e 100644 --- a/yaml_test_runner/src/generator.rs +++ b/yaml_test_runner/src/generator.rs @@ -204,7 +204,7 @@ impl<'a> YamlTests<'a> { fn should_skip_suite(&self) -> Option { if self.should_skip_test("*") { - Some(format!("it's included in skip.yml")) + Some("it's included in skip.yml".into()) } else if let Some(setup) = &self.setup { setup .steps @@ -508,7 +508,7 @@ fn write_mod_files(generated_dir: &Path, toplevel: bool) -> anyhow::Result<()> { mods.sort(); let path = generated_dir.join("mod.rs"); - let mut file = File::create(&path)?; + let mut file = File::create(path)?; let generated_mods: String = mods.join("\n"); file.write_all(generated_mods.as_bytes())?; Ok(()) diff --git a/yaml_test_runner/src/github.rs b/yaml_test_runner/src/github.rs index 0c1705b3..420b5f9d 100644 --- a/yaml_test_runner/src/github.rs +++ b/yaml_test_runner/src/github.rs @@ -66,7 +66,7 @@ pub fn download_test_suites(branch: &str, download_dir: &Path) -> anyhow::Result .build() .unwrap(); - let response = client.get(&url).send()?; + let response = client.get(url).send()?; let tar = GzDecoder::new(response); let mut archive = Archive::new(tar); diff --git a/yaml_test_runner/tests/common/macros.rs b/yaml_test_runner/tests/common/macros.rs index be3c2ed1..5a0ffd98 100644 --- a/yaml_test_runner/tests/common/macros.rs +++ b/yaml_test_runner/tests/common/macros.rs @@ -132,9 +132,9 @@ macro_rules! assert_match { macro_rules! assert_numeric_match { ($expected:expr, $actual:expr) => {{ if $expected.is_i64() { - crate::assert_match!($expected, $actual); + $crate::assert_match!($expected, $actual); } else { - crate::assert_match!($expected, $actual as f64); + $crate::assert_match!($expected, $actual as f64); } }}; }