diff --git a/internal/security/dialFunc.go b/internal/security/dialFunc.go new file mode 100644 index 0000000..35512c7 --- /dev/null +++ b/internal/security/dialFunc.go @@ -0,0 +1,65 @@ +package security + +import ( + "crypto/tls" + "errors" + "net" + "time" +) + +var errIpNotAllowed error = errors.New("ip adress is not allowed") + +// IsDisallowedIP checks if the provided host is a disallowed IP address. +// It parses the given host into an IP address and returns true if the IP is multicast, +// unspecified, a loopback address, or a private address. +func IsDisallowedIP(host string) bool { + ip := net.ParseIP(host) + + return ip.IsMulticast() || ip.IsUnspecified() || ip.IsLoopback() || ip.IsPrivate() +} + +// checkDisallowedIP checks if the IP address of the incoming connection is disallowed. +func checkDisallowedIP(conn net.Conn) error { + ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) + + if IsDisallowedIP(ip) { + conn.Close() + return errIpNotAllowed + } + + return nil +} + +// dialFunc establishes a network connection to a specified address with optional TLS configuration and timeout. +// It first checks if a TLS configuration is provided, if so, it dials with TLS using the provided TLS configuration. +// If not, it dials without TLS. After establishing the connection, it checks if the remote IP is disallowed. +// Returns the connection and any error encountered during the process. +func dialFunc(network string, addr string, timeout time.Duration, tlsConfig *tls.Config) (net.Conn, error) { + dialer := &net.Dialer{ + Timeout: timeout, + } + + if tlsConfig != nil { + conn, err := tls.DialWithDialer(dialer, network, addr, tlsConfig) + if err != nil { + return nil, err + } + + if err := checkDisallowedIP(conn); err != nil { + return nil, err + } + + return conn, err + } + + conn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + if err := checkDisallowedIP(conn); err != nil { + return nil, err + } + + return conn, err +}