diff --git a/aws_config.go b/aws_config.go index 43b5316f..832811a5 100644 --- a/aws_config.go +++ b/aws_config.go @@ -150,9 +150,16 @@ func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, c *C } func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { - httpClient, err := defaultHttpClient(c) - if err != nil { - return nil, err + var err error + var httpClient config.HTTPClient + + if v := c.HTTPClient; v == nil { + httpClient, err = defaultHttpClient(c) + if err != nil { + return nil, err + } + } else { + httpClient = v } apiOptions := make([]func(*middleware.Stack) error, 0) diff --git a/http_client.go b/http_client.go index 6fac1421..540f1704 100644 --- a/http_client.go +++ b/http_client.go @@ -1,32 +1,17 @@ package awsbase import ( - "fmt" - "net/http" - "net/url" - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" ) func defaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) { - var err error + opts, err := c.HTTPTransportOptions() + if err != nil { + return nil, err + } - httpClient := awshttp.NewBuildableClient(). - WithTransportOptions(func(tr *http.Transport) { - if c.Insecure { - tlsConfig := tr.TLSClientConfig - tlsConfig.InsecureSkipVerify = true - } - if c.HTTPProxy != "" { - var proxyUrl *url.URL - proxyUrl, parseErr := url.Parse(c.HTTPProxy) - if parseErr != nil { - err = fmt.Errorf("error parsing HTTP proxy URL: %w", parseErr) - } - tr.Proxy = http.ProxyURL(proxyUrl) - } - }) + httpClient := awshttp.NewBuildableClient().WithTransportOptions(opts) return httpClient, err } diff --git a/internal/config/config.go b/internal/config/config.go index 0887b3b5..d401e3c6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,10 +2,14 @@ package config import ( "bytes" + "crypto/tls" "fmt" + "net/http" + "net/url" "os" "time" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" ) @@ -21,6 +25,7 @@ type Config struct { EC2MetadataServiceEnableState imds.ClientEnableState EC2MetadataServiceEndpoint string EC2MetadataServiceEndpointMode string + HTTPClient *http.Client HTTPProxy string IamEndpoint string Insecure bool @@ -78,6 +83,41 @@ func (c Config) CustomCABundleReader() (*bytes.Reader, error) { return bytes.NewReader(bundle), nil } +// HTTPTransportOptions returns functional options that configures an http.Transport. +// The returned options function is called on both AWS SDKv1 and v2 default HTTP clients. +func (c Config) HTTPTransportOptions() (func(*http.Transport), error) { + var err error + var proxyUrl *url.URL + if c.HTTPProxy != "" { + proxyUrl, err = url.Parse(c.HTTPProxy) + if err != nil { + return nil, fmt.Errorf("error parsing HTTP proxy URL: %w", err) + } + } + + opts := func(tr *http.Transport) { + tr.MaxIdleConnsPerHost = awshttp.DefaultHTTPTransportMaxIdleConnsPerHost + + tlsConfig := tr.TLSClientConfig + if tlsConfig == nil { + tlsConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + tr.TLSClientConfig = tlsConfig + } + + if c.Insecure { + tr.TLSClientConfig.InsecureSkipVerify = true + } + + if proxyUrl != nil { + tr.Proxy = http.ProxyURL(proxyUrl) + } + } + + return opts, nil +} + func (c Config) ResolveSharedConfigFiles() ([]string, error) { v, err := expand.FilePaths(c.SharedConfigFiles) if err != nil { diff --git a/v2/awsv1shim/http_client.go b/v2/awsv1shim/http_client.go index af18b020..f1742422 100644 --- a/v2/awsv1shim/http_client.go +++ b/v2/awsv1shim/http_client.go @@ -1,41 +1,20 @@ package awsv1shim import ( - "crypto/tls" - "fmt" "net/http" - "net/url" - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" "github.com/hashicorp/go-cleanhttp" ) func defaultHttpClient(c *config.Config) (*http.Client, error) { - httpClient := cleanhttp.DefaultPooledClient() - transport := httpClient.Transport.(*http.Transport) - - transport.MaxIdleConnsPerHost = awshttp.DefaultHTTPTransportMaxIdleConnsPerHost - - tlsConfig := transport.TLSClientConfig - if tlsConfig == nil { - tlsConfig = &tls.Config{} - transport.TLSClientConfig = tlsConfig - } - tlsConfig.MinVersion = tls.VersionTLS12 - - if c.Insecure { - tlsConfig.InsecureSkipVerify = true + opts, err := c.HTTPTransportOptions() + if err != nil { + return nil, err } - if c.HTTPProxy != "" { - proxyUrl, err := url.Parse(c.HTTPProxy) - if err != nil { - return nil, fmt.Errorf("error parsing HTTP proxy URL: %w", err) - } - - transport.Proxy = http.ProxyURL(proxyUrl) - } + httpClient := cleanhttp.DefaultPooledClient() + opts(httpClient.Transport.(*http.Transport)) return httpClient, nil } diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 170ce08b..e7093eb7 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -31,9 +31,12 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, return nil, fmt.Errorf("error resolving dual-stack endpoint configuration: %w", err) } - httpClient, err := defaultHttpClient(c) - if err != nil { - return nil, err + httpClient := c.HTTPClient + if httpClient == nil { + httpClient, err = defaultHttpClient(c) + if err != nil { + return nil, err + } } options := &session.Options{