Skip to content

Commit

Permalink
Merge pull request #2 from Snawoot/routing
Browse files Browse the repository at this point in the history
nfmark and resolve-once features
  • Loading branch information
Snawoot committed Feb 22, 2020
2 parents fabc021 + e8b7ace commit c2abd53
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 18 deletions.
59 changes: 58 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ uotp-server -c /etc/letsencrypt/live/example.com/fullchain.pem \
127.0.0.1 26611
```

where 26611 is a Wireguard UDP port. By default server accepts connections on port 8443.
where 26611 is a target UDP service port. By default server accepts connections on port 8443.

Client example:

Expand All @@ -45,6 +45,63 @@ where `0.0.0.0` is a listen address (default is localhost only) and `example.com

See Synopsis for more options.

## Using as a transport for VPN

This application can be used as a transport for UDP-based VPN like Wireguard or OpenVPN.

In case when udp-over-tls-pool server address is covered by routing prefixes tunneled through VPN (for example, if VPN replaces default gateway), udp-over-tls-pool traffic must be excluded. Otherwise connections from uotp-client to uotp-server will be looped back to tunnel. There are at least two ways to resolve that loop.

### Excluding uotp-client traffic with a static route

Classic solution is to define specific route to host with udp-over-tls-pool server. Here is an example Wireguard configuration for Linux:

```
[Interface]
Address = 172.21.123.2/32
PrivateKey = XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PreUp = ip route add 198.51.100.1/32 $(ip route show default | cut -f2- -d\ )
PostDown = ip route del 198.51.100.1/32
[Peer]
PublicKey = YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY
Endpoint = 127.0.0.1:8911
AllowedIPs = 0.0.0.0/0
```

where `198.51.100.1` is an IP address of host with uotp-server.

Such solution should work on all platforms and operating systems, though it leaves all other traffic to uotp-server host unprotected.

### Excluding uotp-client traffic with rule-based routing

Some VPN tunnels use rule-based routing on Linux to exclude own packets from tunnel itself. For example, Wireguard started with `wg-quick` command uses netfilter mark to distinguish tunnel carrier packets. uotp-client is capable to mark own TCP/TLS packets with nfmark as well. To enable this feature you may run uotp-client like this:

```
uotp-client --resolve-once --mark 0xca6c example.com 8443
```

where `0xca6c` is default fwmark for Wireguard set by `wg-quick`. You may check this value with `wg show INTERFACE fwmark`. Once this is enabled no additional for Wireguard configuration is required.

Note that to use netfilter marks uotp-client has to be run as superuser or process has to be started with `CAP_NET_ADMIN` capability. You may set this capability for a process running as restricted user with systemd service file like one below:

```
# /etc/systemd/system/uotp-client.service
[Unit]
Description=UDP over TLS pool client
After=syslog.target network.target
[Service]
Type=notify
User=uotp-client
AmbientCapabilities=CAP_NET_ADMIN
ExecStart=/usr/local/bin/uotp-client --resolve-once --mark 0xca6c example.com 8443
Restart=always
KillMode=process
[Install]
WantedBy=multi-user.target
```

## Synopsis

Server:
Expand Down
17 changes: 15 additions & 2 deletions udp_over_tls_pool/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def parse_args():
parser.add_argument("-l", "--logfile",
help="log file location",
metavar="FILE")
parser.add_argument("-R", "--resolve-once",
help="resolve destination address at startup "
"(useful for VPN transport when DNS becomes "
"unavailable directly)",
action="store_true")

listen_group = parser.add_argument_group('listen options')
listen_group.add_argument("-a", "--bind-address",
Expand Down Expand Up @@ -59,6 +64,10 @@ def parse_args():
default=4.,
type=utils.check_positive_float,
help="server connect timeout in seconds")
pool_group.add_argument("-m", "--mark",
type=utils.check_fwmark,
help="(Linux only) set nfmark for outbound TCP "
"connections")

tls_group = parser.add_argument_group('TLS options')
tls_group.add_argument("--no-tls",
Expand Down Expand Up @@ -101,14 +110,18 @@ async def amain(args, loop): # pragma: no cover
ssl_hostname = ''
elif args.tls_servername:
ssl_hostname = args.tls_servername
else:
ssl_hostname = args.dst_address
if args.cert:
context.load_cert_chain(certfile=args.cert, keyfile=args.key)
else:
context = None

conn_factory = lambda sess_id, recv_cb, queue: upstream.UpstreamConnection(args.dst_address,
dst_host = (utils.resolve_tcp_endpoint(args.dst_address, args.dst_port)
if args.resolve_once else args.dst_address)
conn_factory = lambda sess_id, recv_cb, queue: upstream.UpstreamConnection(dst_host,
args.dst_port, context, ssl_hostname, sess_id, recv_cb, queue,
timeout=args.timeout, backoff=args.backoff)
timeout=args.timeout, backoff=args.backoff, mark=args.mark)
session_factory = lambda recv_cb: client_session.ClientSession(conn_factory,
recv_cb,
pool_size=args.pool_size)
Expand Down
37 changes: 33 additions & 4 deletions udp_over_tls_pool/upstream.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
import asyncio
import logging
import socket

from .constants import LEN_FORMAT, LEN_BYTES

async def open_custom_connection(host, port, *, mark=None, **kwds):
loop = asyncio.get_event_loop()
addr_infos = await loop.getaddrinfo(host, port,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP)
my_exc = None
for ai in addr_infos:
try:
fam, typ, proto, cname, addr = ai
sock = socket.socket(fam, typ, proto)
if mark is not None:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_MARK, mark)
sock.setblocking(False)
await loop.sock_connect(sock, addr)
break
except OSError as exc:
my_exc = exc
else:
raise my_exc

new_kwds = dict(kwds)
new_kwds['sock'] = sock
if kwds.get('ssl') is not None and kwds.get('server_hostname') is None:
new_kwds['server_hostname'] = host
return await asyncio.open_connection(**new_kwds)

class UpstreamConnection:
def __init__(self, host, port, ssl_ctx, server_name, sess_id, recv_cb,
queue, *, timeout=4, backoff=5):
queue, *, timeout=4, backoff=5, mark=None):
self._host = host
self._port = port
self._ssl_ctx = ssl_ctx
Expand All @@ -15,6 +42,7 @@ def __init__(self, host, port, ssl_ctx, server_name, sess_id, recv_cb,
self._queue = queue
self._timeout = timeout
self._backoff = backoff
self._mark = mark
self._logger = logging.getLogger(self.__class__.__name__)
self._worker_task = asyncio.ensure_future(self._worker())
self._logger.debug("Connection 0x%x for session %s started",
Expand Down Expand Up @@ -53,9 +81,10 @@ async def _worker(self):
try:
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(self._host, self._port,
ssl=self._ssl_ctx,
server_hostname=self._server_name),
open_custom_connection(self._host, self._port,
ssl=self._ssl_ctx,
server_hostname=self._server_name,
mark=self._mark),
self._timeout)
except asyncio.TimeoutError:
self._logger.warning("Connection 0x%x for session %s: "
Expand Down
33 changes: 22 additions & 11 deletions udp_over_tls_pool/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import os
import queue
import socket
import ctypes
import time

from . import constants

Expand Down Expand Up @@ -92,6 +90,19 @@ def fail():
return fvalue


def check_fwmark(value):
def fail():
raise argparse.ArgumentTypeError(
"%s is not a valid value" % value)
try:
ivalue = int(value, 0)
except ValueError:
fail()
if not (0 <= ivalue < 1<<32):
fail()
return ivalue


def check_loglevel(arg):
try:
return constants.LogLevel[arg]
Expand Down Expand Up @@ -151,13 +162,13 @@ async def stop(self):
pass


async def wall_clock_sleep(duration, precision=.2):
async def _wall_clock_sleep():
end_time = time.time() + duration
while time.time() < end_time:
await asyncio.sleep(precision)
AF_PREFERENCE = {
socket.AF_INET: 1,
socket.AF_INET6: 2,
}

try:
await asyncio.wait_for(_wall_clock_sleep(), duration)
except asyncio.TimeoutError:
pass

def resolve_tcp_endpoint(host, port=None):
res = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
res = sorted(res, key=lambda v: AF_PREFERENCE.get(v[0], 100))
return res[0][4][0]

0 comments on commit c2abd53

Please sign in to comment.