diff --git a/config/config.go b/config/config.go index 58c5c7a..5f0ccfd 100755 --- a/config/config.go +++ b/config/config.go @@ -287,6 +287,7 @@ type Log struct { // Transport exposes a subset of Transport parameters. reference: https://github.com/golang/go/blob/master/src/net/http/transport.go#L95 type Transport struct { + DialContext DialContext `yaml:"dialContext,omitempty"` TLSHandshakeTimeout time.Duration `yaml:"tlsHandshakeTimeout,omitempty"` DisableKeepAlives bool `yaml:"disableKeepAlives,omitempty"` DisableCompression bool `yaml:"disableCompression,omitempty"` @@ -302,6 +303,11 @@ type Transport struct { ForceAttemptHTTP2 bool `yaml:"forceAttemptHTTP2,omitempty"` } +// DialContext exposes a subset of DialContext parameters. reference: https://github.com/golang/go/blob/master/src/net/http/transport.go#L318 +type DialContext struct { + Timeout time.Duration `yaml:"timeout"` +} + // OriginLog represents log configuration from origin type OriginLog struct { StatusCode StatusCode `yaml:"statusCode"` diff --git a/config/config_test.go b/config/config_test.go index 1308bfb..07e96fb 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -140,6 +140,9 @@ func TestNew(t *testing.T) { WriteBufferSize: 0, ReadBufferSize: 0, ForceAttemptHTTP2: true, + DialContext: DialContext{ + Timeout: 1 * time.Second, + }, }, OriginLog: OriginLog{ StatusCode: StatusCode{ diff --git a/handler/handler.go b/handler/handler.go index 82a6011..973cb21 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httputil" "strings" @@ -88,7 +89,7 @@ func New(cfg config.Proxy, bp httputil.BufferPool, prov service.Authorizationd) ModifyResponse: modifyResponse, Transport: &transport{ prov: prov, - RoundTripper: transportFromCfg(cfg.Transport), + RoundTripper: updateDialContext(transportFromCfg(cfg.Transport), cfg.Transport.DialContext.Timeout), cfg: cfg, noAuthPaths: mapPathToAssertion(cfg.NoAuthPaths), }, @@ -96,6 +97,15 @@ func New(cfg config.Proxy, bp httputil.BufferPool, prov service.Authorizationd) } } +func updateDialContext(t *http.Transport, dialTimeout time.Duration) *http.Transport { + if dialTimeout != time.Duration(0) { + t.DialContext = (&net.Dialer{ + Timeout: dialTimeout, + }).DialContext + } + return t +} + func transportFromCfg(cfg config.Transport) *http.Transport { isZero := func(v interface{}) bool { switch v.(type) { diff --git a/handler/handler_test.go b/handler/handler_test.go index e16ad17..5401a6e 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -5,6 +5,8 @@ import ( "context" "encoding/json" "fmt" + "math" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -565,6 +567,61 @@ func TestNew(t *testing.T) { } } +func Test_updateDialContext(t *testing.T) { + type args struct { + cfg *http.Transport + dialTimeout time.Duration + } + tests := []struct { + name string + args args + want *http.Transport + }{ + { + name: "check dialContext.timeout == 0 is not used", + args: args{ + cfg: &http.Transport{}, + dialTimeout: 0, + }, + want: &http.Transport{}, + }, + { + name: "check dialContext.timeout != 0 is used", + args: args{ + cfg: &http.Transport{}, + dialTimeout: 10 * time.Second, + }, + want: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + }).DialContext, + }, + }, + { + name: "check if dialContext.timeout is negative, timeout is math.MaxInt64", + args: args{ + cfg: &http.Transport{}, + dialTimeout: -1, + }, + want: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: math.MaxInt64, + }).DialContext, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := updateDialContext(tt.args.cfg, tt.args.dialTimeout) + p1 := reflect.ValueOf(got.DialContext).Pointer() + p2 := reflect.ValueOf(tt.want.DialContext).Pointer() + if p1 != p2 { + t.Errorf("updateDialContext() = %+v, want %+v", p1, p2) + } + }) + } +} + func Test_transportFromCfg(t *testing.T) { type args struct { cfg config.Transport diff --git a/test/data/example_config.yaml b/test/data/example_config.yaml index 6723710..95b6f0c 100644 --- a/test/data/example_config.yaml +++ b/test/data/example_config.yaml @@ -56,6 +56,8 @@ proxy: writeBufferSize: 0 readBufferSize: 0 forceAttemptHTTP2: true + dialContext: + timeout: "1s" originLog: statusCode: enable: true