Skip to content

Commit

Permalink
Make endpoints bulk-first
Browse files Browse the repository at this point in the history
  • Loading branch information
ferchault committed Jul 6, 2023
1 parent dd4829f commit e887d40
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 85 deletions.
2 changes: 2 additions & 0 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ testuser:
# Build the dev user image (quietly) and run the hmq CLI inside it.
# The local .devhome directory is mounted as the container's /root so
# credentials/state persist between runs; host networking is used so the
# container can reach a locally running server. Extra CLI args via ARGS=...
runuser:
	docker build -q -t dev-user user
	docker run -t --network="host" --rm -w="/root" -v "$(shell pwd)/.devhome:/root/" dev-user hmq $(ARGS)
# Start an rq worker in the computenode directory using the CachedWorker class.
# NOTE: `$$PYTHONPATH` (not `$PYTHONPATH`) — in a Make recipe a single `$P`
# is expanded by make itself as the (empty) variable P followed by the literal
# text "YTHONPATH"; `$$` passes a literal `$` through to the shell.
worker:
	cd computenode; PYTHONPATH="../user:$$PYTHONPATH" rq worker -c config -w worker.CachedWorker
4 changes: 3 additions & 1 deletion src/computenode/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
class CachedWorker(Worker):
    """rq worker that warms the function cache before running a job and
    cleans up the hmq-id -> rq-id mapping afterwards.

    Note: the scraped diff interleaved the old single-return body with the
    new one; this is the coherent post-commit version.
    """

    def execute_job(self, job, queue):
        # Ensure the serialized function payload is available locally
        # before the job starts executing.
        hmq.api.warm_cache(job.kwargs["function"])
        ret = super().execute_job(job, queue)
        # Job finished: drop its entry from the id translation hash so the
        # server-side mapping does not grow without bound.
        self.connection.hdel("id2id", job.kwargs["hmqid"])
        return ret
5 changes: 4 additions & 1 deletion src/server/app/maintenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ def refill_redis(njobs: int):
logentries.append({"event": "task/fetch", "id": task["id"], "ts": time.time()})
if len(logentries) > 0:
auth.db.logs.insert_many(logentries)
idmapping = {}
for task in tasks:
q.enqueue("hmq.unwrap", **task)
job = q.enqueue("hmq.unwrap", **task)
idmapping[task["hmqid"]] = job.id
redis_conn.hset("id2id", mapping=idmapping)


def flow_control():
Expand Down
132 changes: 87 additions & 45 deletions src/server/app/routers/compute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from typing import Optional, List
import hashlib
import rq
import base64
import time
import uuid
Expand All @@ -10,6 +10,7 @@

from .. import helpers
from .. import auth
from .. import maintenance

app = APIRouter()

Expand Down Expand Up @@ -69,8 +70,8 @@ class TaskSubmit(BaseModel):
digest: str = Field(..., description="SHA256 digest of calls")


@app.post("/task/submit", tags=["compute"])
def task_submit(body: TaskSubmit):
@app.post("/tasks/submit", tags=["compute"])
def tasks_submit(body: TaskSubmit):
verify_challenge(body.challenge)
# todo: verify digest

Expand Down Expand Up @@ -100,67 +101,108 @@ def task_submit(body: TaskSubmit):
return uuids


class TaskFetch(BaseModel):
count: int = Field(..., description="Number of tasks to fetch")
class TasksDelete(BaseModel):
    """Request body for POST /tasks/delete: cancel and remove a batch of tasks."""

    challenge: str = Field(..., description="Encrypted challenge from /auth/challenge")
    delete: List[str] = Field(..., description="List of task ids to cancel and delete.")


# @app.post("/task/fetch", tags=["compute"])
# def task_fetch(body: TaskFetch):
# verify_challenge(body.challenge)
@app.post("/tasks/delete", tags=["compute"])
def tasks_delete(body: TasksDelete):
    """Cancel and delete the given tasks from MongoDB and the rq queue.

    Returns the list of task ids that actually existed (and were removed);
    unknown ids are silently ignored.
    """
    verify_challenge(body.challenge)

    # mongodb — materialize the cursor: it is consumed twice below, and a
    # pymongo cursor is exhausted after the first iteration (the original
    # second pass always produced an empty id list).
    existing = list(auth.db.tasks.find({"id": {"$in": body.delete}}))
    existing_ids = [task["id"] for task in existing]
    logentries = [
        {"event": "task/delete", "id": task_id, "ts": time.time()}
        for task_id in existing_ids
    ]
    auth.db.tasks.delete_many({"id": {"$in": existing_ids}})
    if len(logentries) > 0:
        auth.db.logs.insert_many(logentries)

    # rq — cancel/delete any queued jobs, then drop the id mapping entries.
    # Guard against an empty id list: redis-py HMGET/HDEL require at least
    # one field and raise otherwise.
    if existing_ids:
        for rqid in maintenance.redis_conn.hmget("id2id", existing_ids):
            if rqid is None:
                continue

            job = rq.job.Job.fetch(rqid, connection=maintenance.redis_conn)
            job.cancel()
            job.delete()
        # HDEL takes the fields as varargs, not a single list argument.
        maintenance.redis_conn.hdel("id2id", *existing_ids)

    return existing_ids

# tasks = []
# logentries = []
# for i in range(body.count):
# task = auth.db.tasks.find_one_and_update(
# {"status": "pending"},
# {"$set": {"status": "queued"}},
# return_document=True,
# )
# if task is None:
# break
# tasks.append(
# {"call": task["call"], "function": task["function"], "id": task["id"]}
# )
# logentries.append({"event": "task/fetch", "id": task["id"], "ts": time.time()})

# if len(logentries) > 0:
# auth.db.logs.insert_many(logentries)
# return tasks
class TasksInspect(BaseModel):
    """Request body for POST /tasks/inspect: query the status of a batch of tasks."""

    challenge: str = Field(..., description="Encrypted challenge from /auth/challenge")
    tasks: List[str] = Field(
        ..., description="List of task ids of which to query the status."
    )


@app.post("/tasks/inspect", tags=["compute"])
def task_inspect(body: TasksInspect):
    """Return a mapping of task id -> status; unknown ids map to None."""
    verify_challenge(body.challenge)

    # Pre-populate every requested id with None, then overwrite the known ones.
    statuses = dict.fromkeys(body.tasks)
    for document in auth.db.tasks.find({"id": {"$in": body.tasks}}):
        statuses[document["id"]] = document["status"]

    return statuses


class TaskResult(BaseModel):
    """One worker-reported task outcome: either a result or an error.

    Note: the scraped diff showed both the old `id` field and its renamed
    successor `task`; only `task` exists post-commit (results_store reads
    `result.task`).
    """

    task: str = Field(..., description="Task ID")
    result: str = Field(description="Base64-encoded and encrypted result", default=None)
    error: str = Field(
        description="Base64-encoded and encrypted error message", default=None
    )
    duration: float = Field(..., description="Duration of task execution")


class ResultsStore(BaseModel):
    """Request body for POST /results/store: persist a batch of task outcomes."""

    results: List[TaskResult] = Field(..., description="List of results to store")
    challenge: str = Field(..., description="Encrypted challenge from /auth/challenge")


@app.post("/task/result", tags=["compute"])
def task_result(body: TaskResult):
@app.post("/results/store", tags=["compute"])
def results_store(body: ResultsStore):
    """Persist worker results (or errors) for a batch of tasks.

    Each entry updates the matching task document with its final status,
    payload, and duration. A populated `error` field marks the task as
    failed; otherwise it is completed.

    Note: the scraped diff interleaved the old single-result body
    (`body.error` / `body.result` / `body.id`) with the new per-result
    loop; this is the coherent post-commit version.
    """
    verify_challenge(body.challenge)

    for result in body.results:
        is_error = False
        if result.error is not None:
            is_error = True

        if is_error:
            status = "error"
        else:
            status = "completed"
        update = {
            "status": status,
            "result": result.result,
            "error": result.error,
            "duration": result.duration,
        }
        auth.db.tasks.update_one({"id": result.task}, {"$set": update})

class ResultsRetrieve(BaseModel):
    """Request body for POST /results/retrieve: fetch outcomes for a batch of tasks."""

    tasks: List[str] = Field(..., description="Task IDs")
    challenge: str = Field(..., description="Encrypted challenge from /auth/challenge")


@app.post("/results/retrieve", tags=["compute"])
def results_retreive(body: ResultsRetrieve):
    """Return status/result/error/duration per requested task id.

    Unknown task ids map to None. Fields absent from a task document (e.g.
    no `result`/`duration` yet on a pending task) are returned as None —
    the original guard `if key in entry` tested the dict being iterated and
    was always true, so `task[key]` raised KeyError on such documents.
    """
    verify_challenge(body.challenge)

    results = {_: None for _ in body.tasks}
    for task in auth.db.tasks.find({"id": {"$in": body.tasks}}):
        # dict.get returns None for fields the document does not carry yet.
        results[task["id"]] = {
            key: task.get(key) for key in ("status", "result", "error", "duration")
        }

    return results
55 changes: 17 additions & 38 deletions src/user/hmq/hmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def submit_tasks(self, tag: str, digest: str, calls: list):
"digest": calldigest,
"challenge": self._encrypt(str(time.time()), raw=True),
}
uuids = self._post("/task/submit", payload)
uuids = self._post("/tasks/submit", payload)
return uuids

def warm_cache(self, function: str):
Expand Down Expand Up @@ -166,58 +166,27 @@ def get_challenge(box):

overall_end = time.time()
payload = {
"id": hmqid,
"task": hmqid,
"duration": endtime - starttime,
"walltime": overall_end - overall_start,
"challenge": get_challenge(box),
}

if errormsg is not None:
payload["error"] = errormsg
else:
payload["result"] = result
payload = {"results": [payload], "challenge": get_challenge(box)}

print(payload)
res = requests.post(
f"{baseurl}/task/result",
f"{baseurl}/results/store",
json=payload,
verify=verify,
)

def _get_challenge(self):
return str(time.time())

def _upload_results(self, jobid):
self._jobs = [_ for _ in self._jobs if _.jobid != jobid]

    def fetch_new_tasks(self):
        """Fetch a batch of tasks from the server and push them onto the local queue.

        NOTE(review): this references the retired /task/fetch endpoint that
        the surrounding commit replaces — verify this method still exists
        post-commit before touching it.
        """
        payload = {
            "challenge": self._encrypt(self._get_challenge(), raw=True),
            "count": 1000,
        }
        try:
            tasks = self._post(f"/task/fetch", payload)
            for task in tasks:
                self._queue.put(task)
        except:
            # NOTE(review): bare except swallows everything (including
            # KeyboardInterrupt); the intent here seems to be "back off
            # briefly and let the caller retry" — consider `except Exception`.
            time.sleep(2)

    def worker(self):
        """Run a local worker pool that continuously fetches and executes tasks.

        Spawns one process per CPU, all consuming from a shared queue, then
        loops forever topping the queue up via fetch_new_tasks().
        """
        self._build_box()

        self._queue = mp.Queue()

        # One worker process per CPU; each runs API._worker with the shared
        # queue, the encryption key, and connection settings.
        self._pool = mp.Pool(
            mp.cpu_count(),
            API._worker,
            (self._queue, self._box._key, self._url, self._verify),
        )

        while True:
            # Throttle: only fetch more work once the local queue has
            # drained below 20 entries.
            while True:
                if self._queue.qsize() < 20:
                    break
                time.sleep(2)
            self.fetch_new_tasks()


api = API()

Expand Down Expand Up @@ -264,10 +233,20 @@ def from_file(filename):

    def wait(self, keep=False):
        """Waits for all tasks in a tag and deletes them from the queue unless keep is specified."""
        # NOTE(review): `missing` is computed but never used below — looks
        # like work-in-progress; confirm before removing.
        missing = [
            _ for _ in self.tasks if _ not in self._results and _ not in self._errors
        ]

        # Poll until every task has either a result or an error recorded.
        while len(self.tasks) > len(self._results) + len(self._errors):
            time.sleep(1)

        # retrieve all
        ...
        # NOTE(review): `payload` is built but never sent — the /task/fetch
        # call below is commented out (endpoint retired by this commit).
        payload = {
            "challenge": self._encrypt(self._get_challenge(), raw=True),
            "count": 1000,
        }
        # try:
        # tasks = self._post(f"/task/fetch", payload)

def retrieve(self):
...
Expand Down

0 comments on commit e887d40

Please sign in to comment.