Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: add support for socket options #12702

Merged
merged 2 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 66 additions & 6 deletions pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
Expand All @@ -39,18 +40,66 @@ import (

// NewListener creates a new listner.
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
if l, err = newListener(addr, scheme); err != nil {
return nil, err
}
return wrapTLS(scheme, tlsinfo, l)
return newListener(addr, scheme, WithTLSInfo(tlsinfo))
}

// NewListenerWithOpts creates a new listener which accpets listener options.
func NewListenerWithOpts(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
return newListener(addr, scheme, opts...)
}

func newListener(addr string, scheme string) (net.Listener, error) {
func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
if scheme == "unix" || scheme == "unixs" {
// unix sockets via unix://laddr
return NewUnixListener(addr)
}
return net.Listen("tcp", addr)

lnOpts := newListenOpts(opts...)

switch {
case lnOpts.IsSocketOpts():
// new ListenConfig with socket options.
config, err := newListenConfig(lnOpts.socketOpts)
if err != nil {
return nil, err
}
lnOpts.ListenConfig = config
// check for timeout
fallthrough
case lnOpts.IsTimeout(), lnOpts.IsSocketOpts():
// timeout listener with socket options.
ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = &rwTimeoutListener{
Listener: ln,
readTimeout: lnOpts.readTimeout,
writeTimeout: lnOpts.writeTimeout,
}
case lnOpts.IsTimeout():
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = &rwTimeoutListener{
Listener: ln,
readTimeout: lnOpts.readTimeout,
writeTimeout: lnOpts.writeTimeout,
}
default:
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = ln
}

// only skip if not passing TLSInfo
if lnOpts.skipTLSInfoCheck && !lnOpts.IsTLS() {
return lnOpts.Listener, nil
}
return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener)
}

func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
Expand All @@ -63,6 +112,17 @@ func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, err
return newTLSListener(l, tlsinfo, checkSAN)
}

func newListenConfig(sopts *SocketOpts) (net.ListenConfig, error) {
lc := net.ListenConfig{}
if sopts != nil {
ctls := getControls(sopts)
if len(ctls) > 0 {
lc.Control = ctls.Control
}
}
return lc, nil
}

type TLSInfo struct {
CertFile string
KeyFile string
Expand Down
76 changes: 76 additions & 0 deletions pkg/transport/listener_opts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package transport

import (
"net"
"time"
)

type ListenerOptions struct {
Listener net.Listener
ListenConfig net.ListenConfig

socketOpts *SocketOpts
tlsInfo *TLSInfo
skipTLSInfoCheck bool
writeTimeout time.Duration
readTimeout time.Duration
}

func newListenOpts(opts ...ListenerOption) *ListenerOptions {
lnOpts := &ListenerOptions{}
lnOpts.applyOpts(opts)
return lnOpts
}

func (lo *ListenerOptions) applyOpts(opts []ListenerOption) {
for _, opt := range opts {
opt(lo)
}
}

// IsTimeout returns true if the listener has a read/write timeout defined.
func (lo *ListenerOptions) IsTimeout() bool { return lo.readTimeout != 0 || lo.writeTimeout != 0 }

// IsSocketOpts returns true if the listener options includes socket options.
func (lo *ListenerOptions) IsSocketOpts() bool {
if lo.socketOpts == nil {
return false
}
return lo.socketOpts.ReusePort == true || lo.socketOpts.ReuseAddress == true
}

// IsTLS returns true if listner options includes TLSInfo.
func (lo *ListenerOptions) IsTLS() bool {
if lo.tlsInfo == nil {
return false
}
return lo.tlsInfo.Empty() == false
}

// ListenerOption are options which can be applied to the listener.
type ListenerOption func(*ListenerOptions)

// WithTimeout allows for a read or write timeout to be applied to the listener.
func WithTimeout(read, write time.Duration) ListenerOption {
return func(lo *ListenerOptions) {
lo.writeTimeout = write
lo.readTimeout = read
}
}

// WithSocketOpts defines socket options that will be applied to the listener.
func WithSocketOpts(s *SocketOpts) ListenerOption {
return func(lo *ListenerOptions) { lo.socketOpts = s }
}

// WithTLSInfo adds TLS credentials to the listener.
func WithTLSInfo(t *TLSInfo) ListenerOption {
return func(lo *ListenerOptions) { lo.tlsInfo = t }
}

// WithSkipTLSInfoCheck when true a transport can be created with an https scheme
// without passing TLSInfo, circumventing not presented error. Skipping this check
// also requires that TLSInfo is not passed.
func WithSkipTLSInfoCheck(skip bool) ListenerOption {
return func(lo *ListenerOptions) { lo.skipTLSInfoCheck = skip }
}
173 changes: 173 additions & 0 deletions pkg/transport/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,161 @@ func TestNewListenerTLSInfo(t *testing.T) {
testNewListenerTLSInfoAccept(t, *tlsInfo)
}

func TestNewListenerWithOpts(t *testing.T) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()

tests := map[string]struct {
opts []ListenerOption
scheme string
expectedErr bool
}{
"https scheme no TLSInfo": {
opts: []ListenerOption{},
expectedErr: true,
scheme: "https",
},
"https scheme no TLSInfo with skip check": {
opts: []ListenerOption{WithSkipTLSInfoCheck(true)},
expectedErr: false,
scheme: "https",
},
"https scheme empty TLSInfo with skip check": {
opts: []ListenerOption{
WithSkipTLSInfoCheck(true),
WithTLSInfo(&TLSInfo{}),
},
expectedErr: false,
scheme: "https",
},
"https scheme empty TLSInfo no skip check": {
opts: []ListenerOption{
WithTLSInfo(&TLSInfo{}),
},
expectedErr: true,
scheme: "https",
},
"https scheme with TLSInfo and skip check": {
opts: []ListenerOption{
WithSkipTLSInfoCheck(true),
WithTLSInfo(tlsInfo),
},
expectedErr: false,
scheme: "https",
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
if ln != nil {
defer ln.Close()
}
if test.expectedErr && err == nil {
t.Fatalf("expected error")
}
if !test.expectedErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
}
}

func TestNewListenerWithSocketOpts(t *testing.T) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()

tests := map[string]struct {
opts []ListenerOption
scheme string
expectedErr bool
}{
"nil socketopts": {
opts: []ListenerOption{WithSocketOpts(nil)},
expectedErr: true,
scheme: "http",
},
"empty socketopts": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{})},
expectedErr: true,
scheme: "http",
},

"reuse address": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true})},
scheme: "http",
expectedErr: true,
},
"reuse address with TLS": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReuseAddress: true}),
WithTLSInfo(tlsInfo),
},
scheme: "https",
expectedErr: true,
},
"reuse address and port": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true})},
scheme: "http",
expectedErr: false,
},
"reuse address and port with TLS": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true}),
WithTLSInfo(tlsInfo),
},
scheme: "https",
expectedErr: false,
},
"reuse port with TLS and timeout": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReusePort: true}),
WithTLSInfo(tlsInfo),
WithTimeout(5*time.Second, 5*time.Second),
},
scheme: "https",
expectedErr: false,
},
"reuse port with https scheme and no TLSInfo skip check": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReusePort: true}),
WithSkipTLSInfoCheck(true),
},
scheme: "https",
expectedErr: false,
},
"reuse port": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReusePort: true})},
scheme: "http",
expectedErr: false,
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
if err != nil {
t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err)
}
defer ln.Close()
ln2, err := NewListenerWithOpts(ln.Addr().String(), test.scheme, test.opts...)
if ln2 != nil {
ln2.Close()
}
if test.expectedErr && err == nil {
t.Fatalf("expected error")
}
if !test.expectedErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
}
}

func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
if err != nil {
Expand Down Expand Up @@ -401,3 +556,21 @@ func TestIsClosedConnError(t *testing.T) {
t.Fatalf("expect true, got false (%v)", err)
}
}

func TestSocktOptsEmpty(t *testing.T) {
tests := []struct {
sopts SocketOpts
want bool
}{
{SocketOpts{}, true},
{SocketOpts{ReuseAddress: true, ReusePort: false}, false},
{SocketOpts{ReusePort: true}, false},
}

for i, tt := range tests {
got := tt.sopts.Empty()
if tt.want != got {
t.Errorf("#%d: result of Empty() incorrect: want=%t got=%t", i, tt.want, got)
}
}
}
Loading