Skip to content

Commit

Permalink
Add parameter to make default context
Browse files Browse the repository at this point in the history
If nil, context.WithTimeout(context.Background(), 30*time.Second)
will be used.
  • Loading branch information
at-wat committed Feb 11, 2020
1 parent 38f23fd commit 06468a3
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 14 deletions.
18 changes: 16 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dtls

import (
"context"
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -83,13 +84,26 @@ type Config struct {

LoggerFactory logging.LoggerFactory

// function to make a context used in Dial(), Client(), Server(), and Accept().
// If nil, DefaultConnectContextMaker is used.
ConnectContextMaker func() (context.Context, func())

// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int
}

// DefaultConnectTimeout is a timeout duration used in Dial(), Client() and Server().
const DefaultConnectTimeout = 30 * time.Second
// DefaultConnectContextMaker is a default ConnectContextMaker used in Dial(), Client(), Server(), and Accept().
var DefaultConnectContextMaker = func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 30*time.Second)
}

func (c *Config) connectContextMaker() (context.Context, func()) {
if c.ConnectContextMaker == nil {
return DefaultConnectContextMaker()
}
return c.ConnectContextMaker()
}

const defaultMTU = 1200 // bytes

Expand Down
12 changes: 6 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,30 +216,30 @@ func createConn(ctx context.Context, nextConn net.Conn, flightHandler flightHand
}

// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout in DefaultConnectTimeout.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultConnectTimeout)
ctx, cancel := config.connectContextMaker()
defer cancel()

return DialWithContext(ctx, network, raddr, config)
}

// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout in DefaultConnectTimeout.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.Conn, config *Config) (*Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultConnectTimeout)
ctx, cancel := config.connectContextMaker()
defer cancel()

return ClientWithContext(ctx, conn, config)
}

// Server listens for incoming DTLS connections.
// Connection handshake will timeout in DefaultConnectTimeout.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.Conn, config *Config) (*Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultConnectTimeout)
ctx, cancel := config.connectContextMaker()
defer cancel()

return ServerWithContext(ctx, conn, config)
Expand Down
10 changes: 10 additions & 0 deletions examples/listen-psk/main.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package main

import (
"context"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/examples/util"
Expand All @@ -12,6 +14,10 @@ func main() {
// Prepare the IP to connect to
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}

// Create parent context to cleanup handshaking connections on exit.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

//
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//
Expand All @@ -25,6 +31,10 @@ func main() {
PSKIdentityHint: []byte("Pion DTLS Client"),
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
// Create timeout context for accepted connection.
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
}

// Connect to a DTLS server
Expand Down
10 changes: 10 additions & 0 deletions examples/listen/main.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package main

import (
"context"
"crypto/tls"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/examples/util"
Expand All @@ -18,6 +20,10 @@ func main() {
certificate, genErr := selfsign.GenerateSelfSigned()
util.Check(genErr)

// Create parent context to cleanup handshaking connections on exit.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

//
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//
Expand All @@ -26,6 +32,10 @@ func main() {
config := &dtls.Config{
Certificates: []tls.Certificate{certificate},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
// Create timeout context for accepted connection.
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
}

// Connect to a DTLS server
Expand Down
14 changes: 8 additions & 6 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

// Listen creates a DTLS listener
func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) {
func Listen(network string, laddr *net.UDPAddr, config *Config) (*Listener, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand All @@ -16,21 +16,23 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e
if err != nil {
return nil, err
}
return &listener{
return &Listener{
config: config,
parent: parent,
}, nil
}

// Listener represents a DTLS listener
type listener struct {
type Listener struct {
config *Config
parent *udp.Listener
}

// Accept waits for and returns the next connection to the listener.
// You have to either close or read on all connection that are created.
func (l *listener) Accept() (net.Conn, error) {
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, set ConnectContextMaker.
func (l *Listener) Accept() (net.Conn, error) {
c, err := l.parent.Accept()
if err != nil {
return nil, err
Expand All @@ -41,11 +43,11 @@ func (l *listener) Accept() (net.Conn, error) {
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
// Already Accepted connections are not closed.
func (l *listener) Close() error {
func (l *Listener) Close() error {
return l.parent.Close()
}

// Addr returns the listener's network address.
func (l *listener) Addr() net.Addr {
func (l *Listener) Addr() net.Addr {
return l.parent.Addr()
}

0 comments on commit 06468a3

Please sign in to comment.