Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add silent mode to rpc server and rpc tracker #1268

Merged
merged 2 commits into from
Jun 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/tvm/contrib/rpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random())


def connect_with_retry(addr, timeout=60, retry_period=5):
def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
"""Connect to a TPC address with retry

This function is only reliable to short period of server restart.
Expand All @@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):

retry_period : float
Number of seconds before we retry again.

silent: bool
whether run in silent mode
"""
tstart = time.time()
while True:
Expand All @@ -149,8 +152,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
if period > timeout:
raise RuntimeError(
"Failed to connect to server %s" % str(addr))
logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period)
if not silent:
logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period)
time.sleep(retry_period)


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def _fsend(data):
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env(None)
temp = _server_env(None, None)
# Start connecton
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key
Expand Down
86 changes: 59 additions & 27 deletions python/tvm/contrib/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import multiprocessing
import subprocess
import time
import sys

from ..._ffi.function import register_func
from ..._ffi.base import py_str
Expand All @@ -28,9 +29,12 @@
from . import base
from . base import TrackerCode

def _server_env(load_library):
def _server_env(load_library, logger):
"""Server environment function return temp dir"""
temp = util.tempdir()
if logger is None:
logger = logging.getLogger()

# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath")
def get_workpath(path):
Expand All @@ -41,26 +45,29 @@ def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
m = _load_module(path)
logging.info("load_module %s", path)
logger.info("load_module %s", path)
return m

libs = []
load_library = load_library.split(":") if load_library else []
for file_name in load_library:
file_name = find_lib_path(file_name)[0]
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
logger.info("Load additional library %s", file_name)
temp.libs = libs
return temp


def _serve_loop(sock, addr, load_library):
def _serve_loop(sock, addr, load_library, silent):
"""Server loop"""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
sockfd = sock.fileno()
temp = _server_env(load_library)
temp = _server_env(load_library, logger)
base._ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
logger.info("Finish serving %s", addr)


def _parse_server_opt(opts):
Expand All @@ -71,8 +78,12 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:])
return ret

def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Lisenting loop of the server master."""
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
"""Listening loop of the server master."""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True

def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.

Expand Down Expand Up @@ -115,7 +126,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
unmatch_period_count = 0
# regenerate match key if key is acquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...")
logger.info("no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey),
Expand All @@ -136,7 +147,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
if arr[0] != expect_header:
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close()
logging.info("RPCServer: mismatch key from %s", addr)
logger.info("mismatch key from %s", addr)
continue
else:
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
Expand All @@ -150,7 +161,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
try:
# step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
Expand All @@ -169,22 +180,31 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
tracker_conn.close()
tracker_conn = None
continue
except RuntimeError as exc:
if silent:
return
else:
raise exc

# step 3: serving
logging.info("RPCServer: connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library))
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library, silent))
server_proc.deamon = True
server_proc.start()
# close from our side.
conn.close()
# wait until server process finish or timeout
server_proc.join(opts.get("timeout", None))
if server_proc.is_alive():
logging.info("RPCServer: Timeout in RPC session, kill..")
logger.info("Timeout in RPC session, kill..")
server_proc.terminate()


def _connect_proxy_loop(addr, key, load_library):
def _connect_proxy_loop(addr, key, load_library, silent):
logger = logging.getLogger("RPCProxy")
if silent:
logger.disabled = True
key = "server:" + key
retry_count = 0
max_retry = 5
Expand All @@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library):
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key)
logger.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr))
logger.info("connected to %s", str(addr))
process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr, load_library))
target=_serve_loop, args=(sock, addr, load_library, silent))
process.deamon = True
process.start()
sock.close()
process.join(opts.get("timeout", None))
if process.is_alive():
logging.info("RPCProxyServer: Timeout in RPC session, kill..")
logger.info("Timeout in RPC session, kill..")
process.terminate()
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
logging.info("Error encountered %s, retry in %g sec", str(err), retry_period)
logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
if retry_count > max_retry:
raise RuntimeError("Maximum retry error: last error: %s" % str(err))
time.sleep(retry_period)
Expand Down Expand Up @@ -264,6 +284,9 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.

silent: bool, optional
Whether run this server in silent mode.

key : str, optional
The key used to identify the server in Proxy connection.

Expand All @@ -276,6 +299,7 @@ def __init__(self,
port_end=9199,
is_proxy=False,
use_popen=False,
silent=False,
tracker_addr=None,
key="",
load_library=None,
Expand All @@ -290,8 +314,12 @@ def __init__(self,
self.libs = []
self.custom_addr = custom_addr

self.logger = logging.getLogger("RPCServer")
if silent:
self.logger.disabled = True

if use_popen:
cmd = ["python",
cmd = [sys.executable,
"-m", "tvm.exec.rpc_server",
"--host=%s" % host,
"--port=%s" % port]
Expand All @@ -303,11 +331,14 @@ def __init__(self,
cmd += ["--load-library", load_library]
if custom_addr:
cmd += ["--custom-addr", custom_addr]
if silent:
cmd += ["--silent"]

self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True
self.proc.start()
time.sleep(1)
time.sleep(0.5)
elif not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
Expand All @@ -321,19 +352,20 @@ def __init__(self,
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library, self.custom_addr))
self.sock, self.port, key, tracker_addr, load_library,
self.custom_addr, silent))
self.proc.deamon = True
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library))
target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
self.proc.deamon = True
self.proc.start()

Expand Down
15 changes: 11 additions & 4 deletions python/tvm/contrib/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ def run(self):
def _tracker_server(listen_sock, stop_key):
handler = TrackerServerHandler(listen_sock, stop_key)
handler.run()
logging.info("Tracker Stop signal received, terminating...")


class Tracker(object):
Expand All @@ -327,11 +326,19 @@ class Tracker(object):

port_end : int, optional
The end TCP port to search

silent: bool, optional
Whether run in silent mode
"""
def __init__(self,
host,
port=9190,
port_end=9199):
port_end=9199,
silent=False):
self.logger = logging.getLogger("RPCTracker")
if silent:
self.logger.disabled = True

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
self.stop_key = base.random_key("tracker")
Expand All @@ -347,7 +354,7 @@ def __init__(self,
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCTracker: bind to %s:%d", host, self.port)
self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_tracker_server, args=(sock, self.stop_key))
Expand All @@ -373,7 +380,7 @@ def terminate(self):
self._stop_tracker()
self.proc.join(1)
if self.proc.is_alive():
logging.info("Terminating Tracker Server...")
self.logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc = None

Expand Down
10 changes: 7 additions & 3 deletions python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def main(args):
key=args.key,
tracker_addr=tracker_addr,
load_library=args.load_library,
custom_addr=args.custom_addr)
custom_addr=args.custom_addr,
silent=args.silent)
server.proc.join()


Expand All @@ -51,6 +52,8 @@ def main(args):
and ROCM compilers.")
parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")

parser.set_defaults(fork=True)
args = parser.parse_args()
Expand All @@ -62,6 +65,7 @@ def main(args):
)
multiprocessing.set_start_method('spawn')
else:
logging.info("If you are running ROCM/Metal, \
fork with cause compiler internal error. Try to launch with arg ```--no-fork```")
if not args.silent:
logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args)
13 changes: 10 additions & 3 deletions python/tvm/exec/rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

def main(args):
"""Main funciton"""
tracker = Tracker(args.host, port=args.port)
tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
silent=args.silent)
tracker.proc.join()


Expand All @@ -21,10 +22,15 @@ def main(args):
help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--no-fork', dest='fork', action='store_false',
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")

parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
Expand All @@ -35,6 +41,7 @@ def main(args):
)
multiprocessing.set_start_method('spawn')
else:
logging.info("If you are running ROCM/Metal, \
fork with cause compiler internal error. Try to launch with arg ```--no-fork```")
if not args.silent:
logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args)
Loading