Skip to content

Commit

Permalink
Refactor mock HTTP server tests (#227) (#231)
Browse files Browse the repository at this point in the history
* Refactor mock http server tests

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* de-dup code

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

---------

Signed-off-by: Thomas Farr <tsfarr@amazon.com>
(cherry picked from commit 1035d4e)

Co-authored-by: Thomas Farr <tsfarr@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and Xtansia committed Dec 12, 2023
1 parent 688c31f commit 7004d29
Show file tree
Hide file tree
Showing 17 changed files with 459 additions and 406 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
version: 2.8.0
secured: true

- name: Run Tests (${{ matrix.test-args }})
- name: Run Tests
working-directory: client
run: cargo make test ${{ matrix.test-args }}
env:
Expand Down
10 changes: 3 additions & 7 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,14 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition
category = "OpenSearch"
description = "Generates SSL certificates used for integration tests"
command = "bash"
args =["./.ci/generate-certs.sh"]
args = ["./.ci/generate-certs.sh"]

[tasks.run-opensearch]
category = "OpenSearch"
private = true
condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] }

[tasks.run-opensearch.linux]
command = "./.ci/run-opensearch.sh"

[tasks.run-opensearch.mac]
command = "./.ci/run-opensearch.sh"
command = "bash"
args = ["./.ci/run-opensearch.sh"]

[tasks.run-opensearch.windows]
script_runner = "cmd"
Expand Down
2 changes: 1 addition & 1 deletion api_generator/src/rest_spec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn download_specs(branch: &str, download_dir: &Path) -> anyhow::Result<()> {
.build()
.unwrap();

let response = client.get(&url).send()?;
let response = client.get(url).send()?;
let tar = GzDecoder::new(response);
let mut archive = Archive::new(tar);

Expand Down
1 change: 0 additions & 1 deletion opensearch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ futures = "0.3.1"
http-body-util = "0.1.0"
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
regex="1.4"
sysinfo = "0.29.0"
test-case = "3"
textwrap = "0.16"
Expand Down
4 changes: 2 additions & 2 deletions opensearch/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl From<ClientCertificate> for Credentials {
}
}

#[cfg(any(feature = "aws-auth"))]
#[cfg(feature = "aws-auth")]
impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials {
type Error = super::Error;

Expand All @@ -107,7 +107,7 @@ impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials {
}
}

#[cfg(any(feature = "aws-auth"))]
#[cfg(feature = "aws-auth")]
impl std::convert::TryFrom<aws_types::SdkConfig> for Credentials {
type Error = super::Error;

Expand Down
13 changes: 13 additions & 0 deletions opensearch/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#![allow(unused)]

use serde::Deserialize;

#[derive(Deserialize, Debug)]
Expand Down
85 changes: 33 additions & 52 deletions opensearch/tests/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,77 +29,58 @@
*/

pub mod common;
use common::*;

use crate::common::server::MockServer;
use opensearch::auth::Credentials;

use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder};
use std::io::Write;

#[tokio::test]
async fn basic_auth_header() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
let mut header_value = b"Basic ".to_vec();
{
let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD);
write!(encoder, "username:password").unwrap();
}

assert_header_eq!(
req,
"authorization",
String::from_utf8(header_value).unwrap()
);
server::empty_response()
});

let builder = client::create_builder(format!("http://{}", server.addr()).as_ref())
.auth(Credentials::Basic("username".into(), "password".into()));

let client = client::create(builder);
let _response = client.ping().send().await?;
let mut server = MockServer::start()?;

let client =
server.client_with(|b| b.auth(Credentials::Basic("username".into(), "password".into())));

let _ = client.ping().send().await?;

let request = server.received_request().await?;

assert_eq!(
request.header("authorization"),
Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ=")
);

Ok(())
}

#[tokio::test]
async fn api_key_header() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
let mut header_value = b"ApiKey ".to_vec();
{
let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD);
write!(encoder, "id:api_key").unwrap();
}

assert_header_eq!(
req,
"authorization",
String::from_utf8(header_value).unwrap()
);
server::empty_response()
});

let builder = client::create_builder(format!("http://{}", server.addr()).as_ref())
.auth(Credentials::ApiKey("id".into(), "api_key".into()));

let client = client::create(builder);
let _response = client.ping().send().await?;
let mut server = MockServer::start()?;

let client = server.client_with(|b| b.auth(Credentials::ApiKey("id".into(), "api_key".into())));

let _ = client.ping().send().await?;

let request = server.received_request().await?;

assert_eq!(
request.header("authorization"),
Some("ApiKey aWQ6YXBpX2tleQ==")
);

Ok(())
}

#[tokio::test]
async fn bearer_header() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
assert_header_eq!(req, "authorization", "Bearer access_token");
server::empty_response()
});
let mut server = MockServer::start()?;

let client = server.client_with(|b| b.auth(Credentials::Bearer("access_token".into())));

let _ = client.ping().send().await?;

let builder = client::create_builder(format!("http://{}", server.addr()).as_ref())
.auth(Credentials::Bearer("access_token".into()));
let request = server.received_request().await?;

let client = client::create(builder);
let _response = client.ping().send().await?;
assert_eq!(request.header("authorization"), Some("Bearer access_token"));

Ok(())
}
Expand Down
137 changes: 64 additions & 73 deletions opensearch/tests/aws_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,34 @@
#![cfg(feature = "aws-auth")]

pub mod common;
use aws_config::SdkConfig;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials as AwsCredentials;
use aws_credential_types::{provider::SharedCredentialsProvider, Credentials as AwsCredentials};
use aws_smithy_async::time::StaticTimeSource;
use aws_types::region::Region;
use common::*;
use opensearch::{auth::Credentials, indices::IndicesCreateParts, OpenSearch};
use regex::Regex;
use reqwest::header::HOST;
use common::{server::MockServer, tracing_init};
use opensearch::{
http::{headers::HOST, transport::TransportBuilder},
indices::IndicesCreateParts,
};
use reqwest::header::HeaderValue;
use serde_json::json;
use std::convert::TryInto;
use test_case::test_case;

fn sigv4_config(transport: TransportBuilder, service_name: &str) -> TransportBuilder {
let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test");
let region = Region::new("ap-southeast-2");
let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000

transport
.auth(opensearch::auth::Credentials::AwsSigV4(
SharedCredentialsProvider::new(aws_creds),
region,
))
.service_name(service_name)
.sigv4_time_source(time_source.into())
}

const LOCALHOST: HeaderValue = HeaderValue::from_static("localhost");

#[test_case("es", "10c9be415f4b9f15b12abbb16bd3e3730b2e6c76e0cf40db75d08a44ed04a3a1"; "when service name is es")]
#[test_case("aoss", "34903aef90423aa7dd60575d3d45316c6ef2d57bbe564a152b41bf8f5917abf6"; "when service name is aoss")]
#[test_case("arbitrary", "156e65c504ea2b2722a481b7515062e7692d27217b477828854e715f507e6a36"; "when service name is arbitrary")]
Expand All @@ -35,22 +50,12 @@ async fn aws_auth_signs_correctly(
) -> anyhow::Result<()> {
tracing_init();

let (server, mut rx) = server::capturing_http();
let mut server = MockServer::start()?;

let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test");
let region = Region::new("ap-southeast-2");
let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000
let host = format!("aaabbbcccddd111222333.ap-southeast-2.{service_name}.amazonaws.com");

let transport_builder = client::create_builder(&format!("http://{}", server.addr()))
.auth(Credentials::AwsSigV4(
SharedCredentialsProvider::new(aws_creds),
region,
))
.service_name(service_name)
.sigv4_time_source(time_source.into())
.header(HOST, host.parse().unwrap());
let client = client::create(transport_builder);
let client =
server.client_with(|b| sigv4_config(b, service_name).header(HOST, host.parse().unwrap()));

let _ = client
.indices()
Expand All @@ -74,59 +79,51 @@ async fn aws_auth_signs_correctly(
.send()
.await?;

let sent_req = rx.recv().await.expect("should have sent a request");
let sent_req = server.received_request().await?;

assert_header_eq!(sent_req, "accept", "application/json");
assert_header_eq!(sent_req, "content-type", "application/json");
assert_header_eq!(sent_req, "host", host);
assert_header_eq!(sent_req, "x-amz-date", "20230113T160837Z");
assert_header_eq!(
sent_req,
"x-amz-content-sha256",
"4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564"
assert_eq!(sent_req.header("accept"), Some("application/json"));
assert_eq!(sent_req.header("content-type"), Some("application/json"));
assert_eq!(sent_req.header("host"), Some(host.as_str()));
assert_eq!(sent_req.header("x-amz-date"), Some("20230113T160837Z"));
assert_eq!(
sent_req.header("x-amz-content-sha256"),
Some("4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564")
);
assert_header_eq!(sent_req, "authorization", format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}"));
assert_eq!(sent_req.header("authorization"), Some(format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}").as_str()));

Ok(())
}

#[tokio::test]
async fn aws_auth_get() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
let authorization_header = req.headers()["authorization"].to_str().unwrap();
let re = Regex::new(r"^AWS4-HMAC-SHA256 Credential=id/\d*/us-west-1/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=[a-f,0-9].*$").unwrap();
assert!(
re.is_match(authorization_header),
"{}",
authorization_header
);
assert_header_eq!(
req,
"x-amz-content-sha256",
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
); // SHA of empty string
server::empty_response()
});

let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?;
let _response = client.ping().send().await?;
tracing_init();

let mut server = MockServer::start()?;

let client = server.client_with(|b| sigv4_config(b, "custom").header(HOST, LOCALHOST));

let _ = client.ping().send().await?;

let sent_req = server.received_request().await?;

assert_eq!(sent_req.header("authorization"), Some("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=e5aa6e5d9e1b86b86ed31fbb10dd62b4e93423b77830f8189701421d3e9f65bd"));
assert_eq!(
sent_req.header("x-amz-content-sha256"),
Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
); // SHA of zero-length body

Ok(())
}

#[tokio::test]
async fn aws_auth_post() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
assert_header_eq!(
req,
"x-amz-content-sha256",
"f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9"
); // SHA of the JSON
server::empty_response()
});

let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?;
client
tracing_init();

let mut server = MockServer::start()?;

let client = server.client_with(|b| sigv4_config(b, "custom").header(HOST, LOCALHOST));

let _ = client
.index(opensearch::IndexParts::Index("movies"))
.body(serde_json::json!({
"title": "Moneyball",
Expand All @@ -137,18 +134,12 @@ async fn aws_auth_post() -> anyhow::Result<()> {
.send()
.await?;

Ok(())
}
let sent_req = server.received_request().await?;

fn create_aws_client(addr: &str) -> anyhow::Result<OpenSearch> {
let aws_creds = AwsCredentials::new("id", "secret", None, None, "token");
let creds_provider = SharedCredentialsProvider::new(aws_creds);
let aws_config = SdkConfig::builder()
.credentials_provider(creds_provider)
.region(Region::new("us-west-1"))
.build();
let builder = client::create_builder(addr)
.auth(aws_config.clone().try_into()?)
.service_name("custom");
Ok(client::create(builder))
assert_eq!(
sent_req.header("x-amz-content-sha256"),
Some("f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9")
); // SHA of the JSON

Ok(())
}
Loading

0 comments on commit 7004d29

Please sign in to comment.