Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor mock HTTP server tests #227

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
version: 2.8.0
secured: true

- name: Run Tests (${{ matrix.test-args }})
- name: Run Tests
working-directory: client
run: cargo make test ${{ matrix.test-args }}
env:
Expand Down
10 changes: 3 additions & 7 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,14 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition
category = "OpenSearch"
description = "Generates SSL certificates used for integration tests"
command = "bash"
args =["./.ci/generate-certs.sh"]
args = ["./.ci/generate-certs.sh"]

[tasks.run-opensearch]
category = "OpenSearch"
private = true
condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] }

[tasks.run-opensearch.linux]
command = "./.ci/run-opensearch.sh"

[tasks.run-opensearch.mac]
command = "./.ci/run-opensearch.sh"
command = "bash"
args = ["./.ci/run-opensearch.sh"]

[tasks.run-opensearch.windows]
script_runner = "cmd"
Expand Down
2 changes: 1 addition & 1 deletion api_generator/src/rest_spec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn download_specs(branch: &str, download_dir: &Path) -> anyhow::Result<()> {
.build()
.unwrap();

let response = client.get(&url).send()?;
let response = client.get(url).send()?;
let tar = GzDecoder::new(response);
let mut archive = Archive::new(tar);

Expand Down
1 change: 0 additions & 1 deletion opensearch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ futures = "0.3.1"
http-body-util = "0.1.0"
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
regex="1.4"
sysinfo = "0.29.0"
test-case = "3"
textwrap = "0.16"
Expand Down
4 changes: 2 additions & 2 deletions opensearch/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl From<ClientCertificate> for Credentials {
}
}

#[cfg(any(feature = "aws-auth"))]
#[cfg(feature = "aws-auth")]
impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials {
type Error = super::Error;

Expand All @@ -107,7 +107,7 @@ impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials {
}
}

#[cfg(any(feature = "aws-auth"))]
#[cfg(feature = "aws-auth")]
impl std::convert::TryFrom<aws_types::SdkConfig> for Credentials {
type Error = super::Error;

Expand Down
13 changes: 13 additions & 0 deletions opensearch/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#![allow(unused)]

use serde::Deserialize;

#[derive(Deserialize, Debug)]
Expand Down
85 changes: 33 additions & 52 deletions opensearch/tests/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,77 +29,58 @@
*/

pub mod common;
use common::*;

use crate::common::server::MockServer;
use opensearch::auth::Credentials;

use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder};
use std::io::Write;

#[tokio::test]
async fn basic_auth_header() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
let mut header_value = b"Basic ".to_vec();
{
let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD);
write!(encoder, "username:password").unwrap();
}

assert_header_eq!(
req,
"authorization",
String::from_utf8(header_value).unwrap()
);
server::empty_response()
});

let builder = client::create_builder(format!("http://{}", server.addr()).as_ref())
.auth(Credentials::Basic("username".into(), "password".into()));

let client = client::create(builder);
let _response = client.ping().send().await?;
let mut server = MockServer::start()?;

let client =
server.client_with(|b| b.auth(Credentials::Basic("username".into(), "password".into())));

let _ = client.ping().send().await?;

let request = server.received_request().await?;

assert_eq!(
request.header("authorization"),
Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ=")
);

Ok(())
}

#[tokio::test]
async fn api_key_header() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
let mut header_value = b"ApiKey ".to_vec();
{
let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD);
write!(encoder, "id:api_key").unwrap();
}

assert_header_eq!(
req,
"authorization",
String::from_utf8(header_value).unwrap()
);
server::empty_response()
});

let builder = client::create_builder(format!("http://{}", server.addr()).as_ref())
.auth(Credentials::ApiKey("id".into(), "api_key".into()));

let client = client::create(builder);
let _response = client.ping().send().await?;
let mut server = MockServer::start()?;

let client = server.client_with(|b| b.auth(Credentials::ApiKey("id".into(), "api_key".into())));

let _ = client.ping().send().await?;

let request = server.received_request().await?;

assert_eq!(
request.header("authorization"),
Some("ApiKey aWQ6YXBpX2tleQ==")
);

Ok(())
}

#[tokio::test]
async fn bearer_header() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
assert_header_eq!(req, "authorization", "Bearer access_token");
server::empty_response()
});
let mut server = MockServer::start()?;

let client = server.client_with(|b| b.auth(Credentials::Bearer("access_token".into())));

let _ = client.ping().send().await?;

let builder = client::create_builder(format!("http://{}", server.addr()).as_ref())
.auth(Credentials::Bearer("access_token".into()));
let request = server.received_request().await?;

let client = client::create(builder);
let _response = client.ping().send().await?;
assert_eq!(request.header("authorization"), Some("Bearer access_token"));

Ok(())
}
Expand Down
137 changes: 64 additions & 73 deletions opensearch/tests/aws_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,34 @@
#![cfg(feature = "aws-auth")]

pub mod common;
use aws_config::SdkConfig;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials as AwsCredentials;
use aws_credential_types::{provider::SharedCredentialsProvider, Credentials as AwsCredentials};
use aws_smithy_async::time::StaticTimeSource;
use aws_types::region::Region;
use common::*;
use opensearch::{auth::Credentials, indices::IndicesCreateParts, OpenSearch};
use regex::Regex;
use reqwest::header::HOST;
use common::{server::MockServer, tracing_init};
use opensearch::{
http::{headers::HOST, transport::TransportBuilder},
indices::IndicesCreateParts,
};
use reqwest::header::HeaderValue;
use serde_json::json;
use std::convert::TryInto;
use test_case::test_case;

fn sigv4_config(transport: TransportBuilder, service_name: &str) -> TransportBuilder {
let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test");
let region = Region::new("ap-southeast-2");
let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000

transport
.auth(opensearch::auth::Credentials::AwsSigV4(
SharedCredentialsProvider::new(aws_creds),
region,
))
.service_name(service_name)
.sigv4_time_source(time_source.into())
}

const LOCALHOST: HeaderValue = HeaderValue::from_static("localhost");

#[test_case("es", "10c9be415f4b9f15b12abbb16bd3e3730b2e6c76e0cf40db75d08a44ed04a3a1"; "when service name is es")]
#[test_case("aoss", "34903aef90423aa7dd60575d3d45316c6ef2d57bbe564a152b41bf8f5917abf6"; "when service name is aoss")]
#[test_case("arbitrary", "156e65c504ea2b2722a481b7515062e7692d27217b477828854e715f507e6a36"; "when service name is arbitrary")]
Expand All @@ -35,22 +50,12 @@ async fn aws_auth_signs_correctly(
) -> anyhow::Result<()> {
tracing_init();

let (server, mut rx) = server::capturing_http();
let mut server = MockServer::start()?;

let aws_creds = AwsCredentials::new("test-access-key", "test-secret-key", None, None, "test");
let region = Region::new("ap-southeast-2");
let time_source = StaticTimeSource::from_secs(1673626117); // 2023-01-13 16:08:37 +0000
let host = format!("aaabbbcccddd111222333.ap-southeast-2.{service_name}.amazonaws.com");

let transport_builder = client::create_builder(&format!("http://{}", server.addr()))
.auth(Credentials::AwsSigV4(
SharedCredentialsProvider::new(aws_creds),
region,
))
.service_name(service_name)
.sigv4_time_source(time_source.into())
.header(HOST, host.parse().unwrap());
let client = client::create(transport_builder);
let client =
server.client_with(|b| sigv4_config(b, service_name).header(HOST, host.parse().unwrap()));

let _ = client
.indices()
Expand All @@ -74,59 +79,51 @@ async fn aws_auth_signs_correctly(
.send()
.await?;

let sent_req = rx.recv().await.expect("should have sent a request");
let sent_req = server.received_request().await?;

assert_header_eq!(sent_req, "accept", "application/json");
assert_header_eq!(sent_req, "content-type", "application/json");
assert_header_eq!(sent_req, "host", host);
assert_header_eq!(sent_req, "x-amz-date", "20230113T160837Z");
assert_header_eq!(
sent_req,
"x-amz-content-sha256",
"4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564"
assert_eq!(sent_req.header("accept"), Some("application/json"));
assert_eq!(sent_req.header("content-type"), Some("application/json"));
assert_eq!(sent_req.header("host"), Some(host.as_str()));
assert_eq!(sent_req.header("x-amz-date"), Some("20230113T160837Z"));
assert_eq!(
sent_req.header("x-amz-content-sha256"),
Some("4c770eaed349122a28302ff73d34437cad600acda5a9dd373efc7da2910f8564")
);
assert_header_eq!(sent_req, "authorization", format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}"));
assert_eq!(sent_req.header("authorization"), Some(format!("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/{service_name}/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature={expected_signature}").as_str()));

Ok(())
}

#[tokio::test]
async fn aws_auth_get() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
let authorization_header = req.headers()["authorization"].to_str().unwrap();
let re = Regex::new(r"^AWS4-HMAC-SHA256 Credential=id/\d*/us-west-1/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=[a-f,0-9].*$").unwrap();
assert!(
re.is_match(authorization_header),
"{}",
authorization_header
);
assert_header_eq!(
req,
"x-amz-content-sha256",
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
); // SHA of empty string
server::empty_response()
});

let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?;
let _response = client.ping().send().await?;
tracing_init();

let mut server = MockServer::start()?;

let client = server.client_with(|b| sigv4_config(b, "custom").header(HOST, LOCALHOST));

let _ = client.ping().send().await?;

let sent_req = server.received_request().await?;

assert_eq!(sent_req.header("authorization"), Some("AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/custom/aws4_request, SignedHeaders=accept;content-type;host;x-amz-content-sha256;x-amz-date, Signature=e5aa6e5d9e1b86b86ed31fbb10dd62b4e93423b77830f8189701421d3e9f65bd"));
assert_eq!(
sent_req.header("x-amz-content-sha256"),
Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
); // SHA of zero-length body

Ok(())
}

#[tokio::test]
async fn aws_auth_post() -> anyhow::Result<()> {
let server = server::http(move |req| async move {
assert_header_eq!(
req,
"x-amz-content-sha256",
"f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9"
); // SHA of the JSON
server::empty_response()
});

let client = create_aws_client(format!("http://{}", server.addr()).as_ref())?;
client
tracing_init();

let mut server = MockServer::start()?;

let client = server.client_with(|b| sigv4_config(b, "custom").header(HOST, LOCALHOST));

let _ = client
.index(opensearch::IndexParts::Index("movies"))
.body(serde_json::json!({
"title": "Moneyball",
Expand All @@ -137,18 +134,12 @@ async fn aws_auth_post() -> anyhow::Result<()> {
.send()
.await?;

Ok(())
}
let sent_req = server.received_request().await?;

fn create_aws_client(addr: &str) -> anyhow::Result<OpenSearch> {
let aws_creds = AwsCredentials::new("id", "secret", None, None, "token");
let creds_provider = SharedCredentialsProvider::new(aws_creds);
let aws_config = SdkConfig::builder()
.credentials_provider(creds_provider)
.region(Region::new("us-west-1"))
.build();
let builder = client::create_builder(addr)
.auth(aws_config.clone().try_into()?)
.service_name("custom");
Ok(client::create(builder))
assert_eq!(
sent_req.header("x-amz-content-sha256"),
Some("f3a842f988a653a734ebe4e57c45f19293a002241a72f0b3abbff71e4f5297b9")
); // SHA of the JSON

Ok(())
}
Loading
Loading