Skip to content

Commit

Permalink
chore: add loopback detect for direct outbound
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Dec 20, 2023
1 parent 518c31d commit d4bb4ed
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 4 deletions.
15 changes: 13 additions & 2 deletions adapter/outbound/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package outbound
import (
"context"
"errors"
"fmt"
"net/netip"

N "github.com/metacubex/mihomo/common/net"
Expand All @@ -13,6 +14,7 @@ import (

type Direct struct {
*Base
loopBack *loopBackDetector
}

type DirectOption struct {
Expand All @@ -22,17 +24,23 @@ type DirectOption struct {

// DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if d.loopBack.CheckConn(metadata.SourceAddrPort()) {
return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress())
}
opts = append(opts, dialer.WithResolver(resolver.DefaultResolver))
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
if err != nil {
return nil, err
}
N.TCPKeepAlive(c)
return NewConn(c, d), nil
return d.loopBack.NewConn(NewConn(c, d)), nil
}

// ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if d.loopBack.CheckPacketConn(metadata.SourceAddrPort()) {
return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress())
}
// net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DefaultResolver)
Expand All @@ -45,7 +53,7 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
if err != nil {
return nil, err
}
return newPacketConn(pc, d), nil
return d.loopBack.NewPacketConn(newPacketConn(pc, d)), nil
}

func NewDirectWithOption(option DirectOption) *Direct {
Expand All @@ -60,6 +68,7 @@ func NewDirectWithOption(option DirectOption) *Direct {
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
loopBack: newLoopBackDetector(),
}
}

Expand All @@ -71,6 +80,7 @@ func NewDirect() *Direct {
udp: true,
prefer: C.DualStack,
},
loopBack: newLoopBackDetector(),
}
}

Expand All @@ -82,5 +92,6 @@ func NewCompatible() *Direct {
udp: true,
prefer: C.DualStack,
},
loopBack: newLoopBackDetector(),
}
}
68 changes: 68 additions & 0 deletions adapter/outbound/direct_loopback_detect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package outbound

import (
"net/netip"

"github.com/metacubex/mihomo/common/callback"
C "github.com/metacubex/mihomo/constant"

"github.com/puzpuzpuz/xsync/v3"
)

type loopBackDetector struct {
connMap *xsync.MapOf[netip.AddrPort, struct{}]
packetConnMap *xsync.MapOf[netip.AddrPort, struct{}]
}

func newLoopBackDetector() *loopBackDetector {
return &loopBackDetector{
connMap: xsync.NewMapOf[netip.AddrPort, struct{}](),
packetConnMap: xsync.NewMapOf[netip.AddrPort, struct{}](),
}
}

func (l *loopBackDetector) NewConn(conn C.Conn) C.Conn {
metadata := C.Metadata{}
if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
return conn
}
connAddr := metadata.AddrPort()
if !connAddr.IsValid() {
return conn
}
l.connMap.Store(connAddr, struct{}{})
return callback.NewCloseCallbackConn(conn, func() {
l.packetConnMap.Delete(connAddr)
})
}

func (l *loopBackDetector) NewPacketConn(conn C.PacketConn) C.PacketConn {
metadata := C.Metadata{}
if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
return conn
}
connAddr := metadata.AddrPort()
if !connAddr.IsValid() {
return conn
}
l.packetConnMap.Store(connAddr, struct{}{})
return callback.NewCloseCallbackPacketConn(conn, func() {
l.packetConnMap.Delete(connAddr)
})
}

func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
if !connAddr.IsValid() {
return false
}
_, ok := l.connMap.Load(connAddr)
return ok
}

func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool {
if !connAddr.IsValid() {
return false
}
_, ok := l.packetConnMap.Load(connAddr)
return ok
}
61 changes: 61 additions & 0 deletions common/callback/close_callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package callback

import (
"sync"

C "github.com/metacubex/mihomo/constant"
)

type closeCallbackConn struct {
C.Conn
closeFunc func()
closeOnce sync.Once
}

func (w *closeCallbackConn) Close() error {
w.closeOnce.Do(w.closeFunc)
return w.Conn.Close()
}

func (w *closeCallbackConn) ReaderReplaceable() bool {
return true
}

func (w *closeCallbackConn) WriterReplaceable() bool {
return true
}

func (w *closeCallbackConn) Upstream() any {
return w.Conn
}

func NewCloseCallbackConn(conn C.Conn, callback func()) C.Conn {
return &closeCallbackConn{Conn: conn, closeFunc: callback}
}

type closeCallbackPacketConn struct {
C.PacketConn
closeFunc func()
closeOnce sync.Once
}

func (w *closeCallbackPacketConn) Close() error {
w.closeOnce.Do(w.closeFunc)
return w.PacketConn.Close()
}

func (w *closeCallbackPacketConn) ReaderReplaceable() bool {
return true
}

func (w *closeCallbackPacketConn) WriterReplaceable() bool {
return true
}

func (w *closeCallbackPacketConn) Upstream() any {
return w.PacketConn
}

func NewCloseCallbackPacketConn(conn C.PacketConn, callback func()) C.PacketConn {
return &closeCallbackPacketConn{PacketConn: conn, closeFunc: callback}
}
8 changes: 6 additions & 2 deletions constant/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ type Metadata struct {
SpecialRules string `json:"specialRules"`
RemoteDst string `json:"remoteDestination"`

RawSrcAddr net.Addr `json:"-"`
RawDstAddr net.Addr `json:"-"`
RawSrcAddr net.Addr `json:"-"`
RawDstAddr net.Addr `json:"-"`
// Only domain rule
SniffHost string `json:"sniffHost"`
}
Expand All @@ -162,6 +162,10 @@ func (m *Metadata) SourceAddress() string {
return net.JoinHostPort(m.SrcIP.String(), strconv.FormatUint(uint64(m.SrcPort), 10))
}

func (m *Metadata) SourceAddrPort() netip.AddrPort {
return netip.AddrPortFrom(m.SrcIP.Unmap(), m.SrcPort)
}

func (m *Metadata) SourceDetail() string {
if m.Type == INNER {
return fmt.Sprintf("%s", MihomoName)
Expand Down

0 comments on commit d4bb4ed

Please sign in to comment.