Skip to content

Commit

Permalink
Fetching results
Browse files Browse the repository at this point in the history
  • Loading branch information
ferchault committed Jul 7, 2023
1 parent e887d40 commit a43188c
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 35 deletions.
9 changes: 6 additions & 3 deletions src/server/app/routers/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,13 @@ def results_retreive(body: ResultsRetrieve):

results = {_: None for _ in body.tasks}
for task in auth.db.tasks.find({"id": {"$in": body.tasks}}):
entry = {"status": None, "result": None, "error": None, "duration": None}
entry = {"result": None, "error": None}
has_info = False
for key in entry.keys():
if key in entry:
if key in task:
entry[key] = task[key]
results[task["id"]] = entry
has_info = True
if has_info:
results[task["id"]] = entry

return results
136 changes: 104 additions & 32 deletions src/user/hmq/hmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def _encrypt(self, obj, raw=False):
message = base64.b64encode(self._box.encrypt(message)).decode("ascii")
return message

def _decrypt(self, obj):
self._build_box()

message = obj.encode("utf-8")
message = json.loads(
self._box.decrypt(base64.b64decode(message.decode("ascii")))
)
return message

def _post(self, endpoint, payload):
response = requests.post(
self._url + endpoint, json=payload, verify=self._verify
Expand Down Expand Up @@ -111,6 +120,22 @@ def register_function(self, remote_function: dict, digest: str):
pass
return False

def retrieve_results(self, tasks: list[str]):
payload = {
"tasks": tasks,
"challenge": self._encrypt(str(time.time()), raw=True),
}
results = {}
for task, result in self._post("/results/retrieve", payload).items():
if result is None:
results[task] = result
continue
for key in "result error".split():
if result[key] is not None:
result[key] = self._decrypt(result[key])
results[task] = result
return results

def submit_tasks(self, tag: str, digest: str, calls: list):
calls = [self._encrypt(call) for call in calls]
callstr = json.dumps(calls)
Expand Down Expand Up @@ -156,13 +181,17 @@ def get_challenge(box):
starttime = time.time()
try:
result = functions[function](*args, **kwargs)
result = base64.b64encode(
box.encrypt(json.dumps(result).encode("utf8"))
).decode("ascii")
errormsg = None
except:
errormsg = traceback.format_exc()
errormsg = base64.b64encode(
box.encrypt(json.dumps(errormsg).encode("utf8"))
).decode("ascii")
result = None
endtime = time.time()
result = base64.b64encode(
box.encrypt(json.dumps(result).encode("utf8"))
).decode("ascii")

overall_end = time.time()
payload = {
Expand All @@ -177,7 +206,6 @@ def get_challenge(box):
payload["result"] = result
payload = {"results": [payload], "challenge": get_challenge(box)}

print(payload)
res = requests.post(
f"{baseurl}/results/store",
json=payload,
Expand Down Expand Up @@ -206,50 +234,94 @@ def to_file(self, filename):
"name": self.name,
"ntasks": len(self.tasks),
"status": {
"PENDING": len(self.tasks) - len(self._results),
"PENDING": len(self.tasks) - len(self._results) - len(self._errors),
"DONE": len(self._results),
"FAILED": len(self._errors),
},
}
payload = json.dumps(
{"tasks": self.tasks, "results": self._results, "errors": self._errors}
)
payload = ""
for task in self.tasks:
result = None
if task in self._results:
result = self._results[task]
error = None
if task in self._errors:
error = self._errors[task]
payload += (
json.dumps({"task": task, "result": result, "error": error}) + "\n"
)

with open(filename, "w") as f:
f.write(json.dumps(meta) + "\n" + payload)

@staticmethod
def from_file(filename):
meta = None
tasks = []
results = {}
errors = {}
for line in open(filename):
if meta is None:
meta = json.loads(line.strip())
else:
payload = json.loads(line.strip())
t = Tag(meta["name"])
t.tasks = payload["tasks"]
t._results = payload["results"]
t._errors = payload["errors"]
row = json.loads(line.strip())
tasks.append(row["task"])
if row["result"] is not None:
results[row["task"]] = row["result"]
if row["error"] is not None:
errors[row["task"]] = row["error"]

t = Tag(meta["name"])
t.tasks = tasks
t._results = results
t._errors = errors
return t

def wait(self, keep=False):
"""Waits for all tasks in a tag and deletes them from the queue unless keep is specified."""
missing = [
_ for _ in self.tasks if _ not in self._results and _ not in self._errors
]

while len(self.tasks) > len(self._results) + len(self._errors):
time.sleep(1)

# retrieve all
payload = {
"challenge": self._encrypt(self._get_challenge(), raw=True),
"count": 1000,
}
# try:
# tasks = self._post(f"/task/fetch", payload)

def retrieve(self):
...
def _pull_batch(self, tasklist):
if len(tasklist) > 0:
results = api.retrieve_results(tasklist)
for task, result in results.items():
if result is not None:
if result["result"] is not None:
self._results[task] = result["result"]
else:
self._errors[task] = result["error"]

def pull(self, blocking=False) -> int:
remaining = len(self.tasks) - len(self._results) - len(self._errors)
while remaining > 0:
tasklist = []
for task in self.tasks:
if task not in self._results and task not in self._errors:
tasklist.append(task)
if len(tasklist) > 50:
self._pull_batch(tasklist)
tasklist = []
self._pull_batch(tasklist)
remaining = len(self.tasks) - len(self._results) - len(self._errors)
if not blocking:
break
return remaining

@property
def results(self):
res = []
for task in self.tasks:
if task in self._results:
res.append(self._results[task])
else:
res.append(None)
return res

@property
def errors(self):
res = []
for task in self.tasks:
if task in self._errors:
res.append(self._errors[task])
else:
res.append(None)
return res


def setup(url, key):
Expand Down

0 comments on commit a43188c

Please sign in to comment.