Skip to content

Commit

Permalink
Split up /call and /logs.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Aug 14, 2024
1 parent fe8cb72 commit 7cb3b41
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 90 deletions.
2 changes: 2 additions & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
DEFAULT_SERVER_HOST = "0.0.0.0"

LOGGING_WAIT_TIME = 0.5
LOGS_TO_SHOW_UP_CHECK_TIME = 0.1
MAX_LOGS_TO_SHOW_UP_WAIT_TIME = 5

# Commands
SERVER_START_CMD = f"{sys.executable} -m runhouse.servers.http.http_server"
Expand Down
241 changes: 165 additions & 76 deletions runhouse/servers/http/http_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import json
import logging
import time
import warnings

from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from pathlib import Path
from random import randrange
Expand Down Expand Up @@ -36,7 +39,7 @@
serialize_data,
)

from runhouse.utils import generate_default_name
from runhouse.utils import generate_default_name, thread_coroutine


# Make this global so connections are pooled across instances of HTTPClient
Expand Down Expand Up @@ -433,54 +436,59 @@ def call_module_method(
+ (f".{method_name}" if method_name else "")
)
serialization = serialization or "pickle"
res = retry_with_exponential_backoff(session.post)(
self._formatted_url(f"{key}/{method_name}"),
json=CallParams(
data=serialize_data(data, serialization),
serialization=serialization,
run_name=run_name,
stream_logs=stream_logs,
save=save,
remote=remote,
).dict(),
stream=True,
headers=rns_client.request_headers(resource_address),
auth=self.auth,
verify=self.verify,
)

if res.status_code != 200:
raise ValueError(
f"Error calling {method_name} on server: {res.content.decode()}"
)
error_str = f"Error calling {method_name} on {key} on server"

# We get back a stream of intermingled log outputs and results (maybe None, maybe error, maybe single result,
# maybe a stream of results), so we need to separate these out.
result = None
res_iter = res.iter_lines(chunk_size=None)
# We need to manually iterate through res_iter so we can try/except to bypass a ChunkedEncodingError bug
while True:
try:
responses_json = next(res_iter)
except requests.exceptions.ChunkedEncodingError:
# Some silly bug in urllib3, see https://github.com/psf/requests/issues/4248
continue
except StopIteration:
break
except StopAsyncIteration:
break

resp = json.loads(responses_json)
output_type = resp["output_type"]
result = handle_response(
resp, output_type, error_str, log_formatter=self.log_formatter
with ThreadPoolExecutor() as executor:
# Run logs request in separate thread. Can start it before because it'll wait 5 seconds for the
# calls request to begin.
if stream_logs:
logs_future = executor.submit(
thread_coroutine,
self._alogs_request(
run_name=run_name,
serialization=serialization,
error_str=error_str,
resource_address=resource_address,
create_async_client=True,
),
)
response = retry_with_exponential_backoff(session.post)(
self._formatted_url(f"{key}/{method_name}"),
json=CallParams(
data=serialize_data(data, serialization),
serialization=serialization,
run_name=run_name,
stream_logs=stream_logs,
save=save,
remote=remote,
).dict(),
headers=rns_client.request_headers(resource_address),
auth=self.auth,
verify=self.verify,
)

result = self._process_call_result(result, system, output_type)
if response.status_code != 200:
raise ValueError(
f"Error calling {method_name} on server: {response.content.decode()}"
)

resp_json = response.json()
function_result = handle_response(
resp_json,
resp_json["output_type"],
error_str,
log_formatter=self.log_formatter,
)
output_type = resp_json["output_type"]
if logs_future:
_ = logs_future.result()

end = time.time()

function_result = self._process_call_result(
function_result, system, output_type
)

if method_name:
log_str = (
f"Time to call {key}.{method_name}: {round(end - start, 2)} seconds"
Expand All @@ -489,7 +497,7 @@ def call_module_method(
log_str = f"Time to get {key}: {round(end - start, 2)} seconds"

logging.info(log_str)
return result
return function_result

async def acall(
self,
Expand Down Expand Up @@ -519,6 +527,75 @@ async def acall(
system=self.system,
)

async def _acall_request(
self,
key: str,
method_name: str,
run_name: str,
serialization: str,
stream_logs: bool,
run_async: bool,
save: bool,
remote: bool,
data: Any = None,
resource_address=None,
):
response = await self.async_session.post(
self._formatted_url(f"{key}/{method_name}"),
json=CallParams(
data=serialize_data(data, serialization),
serialization=serialization,
run_name=run_name,
stream_logs=stream_logs,
save=save,
remote=remote,
run_async=run_async,
).dict(),
headers=rns_client.request_headers(resource_address),
)
if response.status_code != 200:
raise ValueError(
f"Error calling {method_name} on server: {response.content.decode()}"
)

resp_json = response.json()
return resp_json

async def _alogs_request(
self,
run_name: str,
serialization: str,
error_str: str,
resource_address=None,
create_async_client=False,
) -> None:
# When running this in another thread, we need to explicitly create an async client here. When running within
# the main thread, we can use the client that was passed in.
if create_async_client:
client = httpx.AsyncClient(auth=self.auth, verify=self.verify, timeout=None)
else:
client = self.async_session

async with client.stream(
"GET",
self._formatted_url(f"logs/{run_name}/{serialization}"),
headers=rns_client.request_headers(resource_address),
) as res:
async for response_json in res.aiter_lines():
resp = json.loads(response_json)
output_type = resp["output_type"]
if output_type not in [
OutputType.EXCEPTION,
OutputType.STDOUT,
OutputType.STDERR,
]:
raise ValueError(
f"Unexpected output type from logs function: {output_type}"
)
handle_response(
resp, output_type, error_str, log_formatter=self.log_formatter
)

async def acall_module_method(
self,
key: str,
Expand Down Expand Up @@ -550,47 +627,59 @@ async def acall_module_method(
+ (f".{method_name}" if method_name else "")
)
serialization = serialization or "pickle"
async with self.async_session.stream(
"POST",
self._formatted_url(f"{key}/{method_name}"),
json=CallParams(
data=serialize_data(data, serialization),
serialization=serialization,
error_str = f"Error calling {method_name} on {key} on server"

acall_request = asyncio.create_task(
self._acall_request(
key=key,
method_name=method_name,
run_name=run_name,
serialization=serialization,
stream_logs=stream_logs,
run_async=run_async,
save=save,
remote=remote,
run_async=run_async,
).dict(),
headers=rns_client.request_headers(resource_address),
) as res:
if res.status_code != 200:
raise ValueError(
f"Error calling {method_name} on server: {res.content.decode()}"
)
error_str = f"Error calling {method_name} on {key} on server"
data=data,
resource_address=resource_address,
)
)
alogs_request = asyncio.create_task(
self._alogs_request(
run_name=run_name,
serialization=serialization,
error_str=error_str,
resource_address=resource_address,
)
)

# We get back a stream of intermingled log outputs and results (maybe None, maybe error, maybe single result,
# maybe a stream of results), so we need to separate these out.
result = None
async for response_json in res.aiter_lines():
resp = json.loads(response_json)
output_type = resp["output_type"]
result = handle_response(
resp, output_type, error_str, log_formatter=self.log_formatter
output_type = None
function_result = None
for fut_result in asyncio.as_completed([acall_request, alogs_request]):
resp_json = await fut_result
# alogs_request returns None, acall_request returns a legitimate result
if resp_json is not None:
function_result = handle_response(
resp_json,
resp_json["output_type"],
error_str,
log_formatter=self.log_formatter,
)
result = self._process_call_result(result, system, output_type)
output_type = resp_json["output_type"]

end = time.time()
end = time.time()

if method_name:
log_str = (
f"Time to call {key}.{method_name}: {round(end - start, 2)} seconds"
)
else:
log_str = f"Time to get {key}: {round(end - start, 2)} seconds"
logging.info(log_str)
return result
function_result = self._process_call_result(
function_result, system, output_type
)

if method_name:
log_str = (
f"Time to call {key}.{method_name}: {round(end - start, 2)} seconds"
)
else:
log_str = f"Time to get {key}: {round(end - start, 2)} seconds"
logging.info(log_str)
return function_result

def put_object(self, key: str, value: Any, env=None):
return self.request_json(
Expand Down
Loading

0 comments on commit 7cb3b41

Please sign in to comment.