Monitor task: add timeout for monitor task, and support killing #247

Merged (3 commits, Aug 21, 2024)
22 changes: 20 additions & 2 deletions aiida_workgraph/cli/cmd_task.py
@@ -71,11 +71,29 @@ def task_play(process, tasks, timeout, wait):
def task_skip(process, tasks, timeout, wait):
"""Skip task."""
from aiida.engine.processes import control
from aiida_workgraph.utils.control import skip_task
from aiida_workgraph.utils.control import skip_tasks

for task in tasks:
try:
skip_task(process, task, timeout, wait)
skip_tasks(process.pk, task, timeout, wait)

except control.ProcessTimeoutException as exception:
echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}")


@workgraph_task.command("kill")
@arguments.PROCESS()
@click.argument("tasks", nargs=-1)
@options.TIMEOUT()
@options.WAIT()
@decorators.with_dbenv()
def task_kill(process, tasks, timeout, wait):
"""Kill task."""
from aiida.engine.processes import control
from aiida_workgraph.utils.control import kill_tasks

print("tasks", tasks)
try:
kill_tasks(process.pk, tasks, timeout, wait)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}")
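For reference, the new `kill` subcommand delegates to `kill_tasks` from `aiida_workgraph.utils.control`, so the same action can be triggered from Python. A minimal sketch, assuming a configured AiiDA profile; the process pk (123) and the task name ("monitor1") are illustrative placeholders:

from aiida import load_profile
from aiida_workgraph.utils.control import kill_tasks

load_profile()  # connect to the default AiiDA profile

# pk 123 and the task name "monitor1" stand in for a running WorkGraph
# process and one of its monitor tasks.
kill_tasks(123, ["monitor1"], timeout=5, wait=False)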
104 changes: 60 additions & 44 deletions aiida_workgraph/engine/workgraph.py
@@ -144,6 +144,7 @@ def load_instance_state(
super().load_instance_state(saved_state, load_context)
# Load the context
self._context = saved_state[self._CONTEXT]
self._temp = {"awaitables": {}}

self.set_logger(self.node.logger)

@@ -251,11 +252,8 @@ def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None:
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

awaitable.resolved = True
# find awaitable in the self._awaitables with the same pk and remove it
for index, a in enumerate(self._awaitables):
if a.pk == awaitable.pk:
self._awaitables.pop(index)
break
# remove the awaitable from the list
self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk]

if not self.has_terminated():
# the process may be terminated, for example, if the process was killed or excepted
@@ -388,7 +386,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:

:param awaitable: an Awaitable instance
"""
print("on awaitable finished: ", awaitable.key)
self.logger.debug(f"Awaitable {awaitable.key} finished.")

if isinstance(awaitable.pk, int):
self.logger.info(
@@ -418,11 +416,24 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:
)
)
try:
results = awaitable.result()
self.set_normal_task_results(awaitable.key, results)
# if awaitable is cancelled, the result is None
if awaitable.cancelled():
self.set_task_state_info(awaitable.key, "state", "KILLED")
# set child tasks state to SKIPPED
self.set_tasks_state(
self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED"
)
self.report(f"Task: {awaitable.key} cancelled.")
else:
results = awaitable.result()
self.set_normal_task_results(awaitable.key, results)
except Exception as e:
self.logger.error(f"Error in awaitable {awaitable.key}: {e}")
self.set_task_state_info(awaitable.key, "state", "FAILED")
# set child tasks state to SKIPPED
self.set_tasks_state(
self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED"
)
self.report(f"Task: {awaitable.key} failed.")
self.run_error_handlers(awaitable.key)
value = None
@@ -460,6 +471,7 @@ def on_create(self) -> None:
self.node.label = wgdata["name"]

def setup(self) -> None:
"""Setup the variables in the context."""
# track if the awaitable callback is added to the runner
self.ctx._awaitable_actions = []
self.ctx._new_data = {}
@@ -472,6 +484,8 @@ def setup(self) -> None:
self.ctx._execution_count = 0
# init task results
self.set_task_results()
# data not to be persisted, because they are not serializable
self._temp = {"awaitables": {}}
# while workgraph
if self.ctx._workgraph["workgraph_type"].upper() == "WHILE":
self.ctx._max_iteration = self.ctx._workgraph.get("max_iteration", 1000)
@@ -577,15 +591,18 @@ def apply_task_actions(self, msg: dict) -> None:
if action.upper() == "RESET":
for name in tasks:
self.reset_task(name)
if action.upper() == "PAUSE":
elif action.upper() == "PAUSE":
for name in tasks:
self.pause_task(name)
if action.upper() == "PLAY":
elif action.upper() == "PLAY":
for name in tasks:
self.play_task(name)
if action.upper() == "SKIP":
elif action.upper() == "SKIP":
for name in tasks:
self.skip_task(name)
elif action.upper() == "KILL":
for name in tasks:
self.kill_task(name)

def reset_task(
self,
@@ -629,23 +646,43 @@ def skip_task(self, name: str) -> None:
self.set_task_state_info(name, "state", "SKIPPED")
self.report(f"Task {name} action: SKIP.")

def kill_task(self, name: str) -> None:
"""Kill task.
This is used to kill the awaitable and monitor task.
"""
if self.get_task_state_info(name, "state") in ["RUNNING"]:
if self.ctx._tasks[name]["metadata"]["node_type"].upper() in [
"AWAITABLE",
"MONITOR",
]:
try:
self._temp["awaitables"][name].cancel()
self.set_task_state_info(name, "state", "KILLED")
self.report(f"Task {name} action: KILLED.")
except Exception as e:
self.logger.error(f"Error in killing task {name}: {e}")

def continue_workgraph(self) -> None:
print("Continue workgraph.")
self.report("Continue workgraph.")
# self.update_workgraph_from_base()
task_to_run = []
for name, task in self.ctx._tasks.items():
# update task state
if self.get_task_state_info(task["name"], "state") in [
"CREATED",
"RUNNING",
"FINISHED",
"FAILED",
"SKIPPED",
]:
if (
self.get_task_state_info(task["name"], "state")
in [
"CREATED",
"RUNNING",
"FINISHED",
"FAILED",
"SKIPPED",
]
or name in self.ctx._executed_tasks
):
continue
ready, _ = self.is_task_ready_to_run(name)
if ready and self.task_should_run(name):
if ready:
task_to_run.append(name)
#
self.report("tasks ready to run: {}".format(",".join(task_to_run)))
@@ -898,24 +935,6 @@ def check_for_conditions(self) -> bool:
self.ctx._count += 1
return should_run

def task_should_run(self, name: str, uuid: str = None) -> bool:
"""Check if the task should not run.
If name not in executed tasks, return True.
If uuid is not None, check if the task with the same name and uuid is the first one.
In a extreme case, the engine try to run the same task multiple times at the same time.
We only allow the first one to run.
"""
name_and_uuids = [
label.split(".")
for label in self.ctx._executed_tasks
if label.split(".")[0] == name
]
if len(name_and_uuids) == 0:
return True
# find the index of current uuid
index = [i for i, item in enumerate(name_and_uuids) if item[1] == uuid][0]
return index == 0

def remove_executed_task(self, name: str) -> None:
"""Remove labels with name from executed tasks."""
self.ctx._executed_tasks = [
@@ -936,7 +955,6 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
create_data_node,
update_nested_dict_with_special_keys,
)
from uuid import uuid4

for name in names:
# skip if the max number of awaitables is reached
@@ -959,11 +977,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
# skip if the task is already executed
if name in self.ctx._executed_tasks:
continue
else:
uuid = str(uuid4())
self.ctx._executed_tasks.append(f"{name}.{uuid}")
if not self.task_should_run(name, uuid):
continue
self.ctx._executed_tasks.append(name)
print("-" * 60)

self.report(f"Run task: {name}, type: {task['metadata']['node_type']}")
@@ -1162,13 +1176,15 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
for key in self.ctx._tasks[name]["metadata"]["args"]:
kwargs.pop(key, None)
# add function and interval to the args
args = [executor, kwargs.pop("interval", 1), *args]
args = [executor, kwargs.pop("interval"), kwargs.pop("timeout"), *args]
awaitable_target = asyncio.ensure_future(
self.run_executor(monitor, args, kwargs, var_args, var_kwargs),
loop=self.loop,
)
awaitable = self.construct_awaitable_function(name, awaitable_target)
self.set_task_state_info(name, "state", "RUNNING")
# save the awaitable to the temp, so that we can kill it if needed
self._temp["awaitables"][name] = awaitable_target
self.to_context(**{name: awaitable})
elif task["metadata"]["node_type"].upper() in ["NORMAL"]:
# Normal task is created by decorating a function with @task()
7 changes: 6 additions & 1 deletion aiida_workgraph/executors/monitors.py
@@ -1,10 +1,15 @@
async def monitor(function, interval, *args, **kwargs):
async def monitor(function, interval, timeout, *args, **kwargs):
"""Monitor the function until it returns True or the timeout is reached."""
import asyncio
import time

start_time = time.time()
while True:
result = function(*args, **kwargs)
if result:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Timeout reached for monitor function {function}")
await asyncio.sleep(interval)


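The new `timeout` argument makes the polling loop above self-terminating. A minimal sketch of driving the coroutine directly, outside the WorkGraph engine; the file-existence check and the path are hypothetical stand-ins for a real monitor function:

import asyncio
import os

from aiida_workgraph.executors.monitors import monitor

def file_exists(filepath):
    """Return True once the watched file appears."""
    return os.path.exists(filepath)

# Poll every second; raise TimeoutError after 10 seconds if the file never
# shows up. "/tmp/results.json" is a placeholder path.
asyncio.run(monitor(file_exists, 1, 10, "/tmp/results.json"))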
12 changes: 9 additions & 3 deletions aiida_workgraph/tasks/monitors.py
@@ -10,14 +10,16 @@ class TimeMonitor(Task):
node_type = "MONITOR"
catalog = "Monitor"
args = ["datetime"]
kwargs = ["interval"]
kwargs = ["interval", "timeout"]

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
self.inputs.new("workgraph.any", "datetime")
inp = self.inputs.new("workgraph.any", "interval")
inp.add_property("workgraph.any", default=1.0)
inp = self.inputs.new("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
inp = self.inputs.new("workgraph.any", "_wait")
inp.link_limit = 100000
self.outputs.new("workgraph.any", "_wait")
@@ -37,14 +39,16 @@ class FileMonitor(Task):
node_type = "MONITOR"
catalog = "Monitor"
args = ["filepath"]
kwargs = ["interval"]
kwargs = ["interval", "timeout"]

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
self.inputs.new("workgraph.any", "filepath")
inp = self.inputs.new("workgraph.any", "interval")
inp.add_property("workgraph.any", default=1.0)
inp = self.inputs.new("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
inp = self.inputs.new("workgraph.any", "_wait")
inp.link_limit = 100000
self.outputs.new("workgraph.any", "_wait")
@@ -64,7 +68,7 @@ class TaskMonitor(Task):
node_type = "MONITOR"
catalog = "Monitor"
args = ["task_name"]
kwargs = ["workgraph_pk", "workgraph_name", "interval"]
kwargs = ["interval", "timeout", "workgraph_pk", "workgraph_name"]

def create_sockets(self) -> None:
self.inputs.clear()
@@ -74,6 +78,8 @@ def create_sockets(self) -> None:
self.inputs.new("workgraph.any", "task_name")
inp = self.inputs.new("workgraph.any", "interval")
inp.add_property("workgraph.any", default=1.0)
inp = self.inputs.new("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
inp = self.inputs.new("workgraph.any", "_wait")
inp.link_limit = 100000
self.outputs.new("workgraph.any", "_wait")
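With the extra `timeout` input (default 86400 s, i.e. one day), a monitor task can be given an explicit deadline when the workgraph is built. A rough sketch, assuming the built-in time monitor is registered under the identifier "workgraph.time_monitor" and that inputs can be passed to `WorkGraph.add_task` as keyword arguments; both of those names and the input values are assumptions, not confirmed by this diff:

from datetime import datetime, timedelta

from aiida_workgraph import WorkGraph

wg = WorkGraph("monitor_with_timeout")
# Wait until a point in time, polling every 5 s, but give up after 600 s.
wg.add_task(
    "workgraph.time_monitor",   # assumed identifier of the built-in TimeMonitor
    name="monitor1",
    datetime=datetime.now() + timedelta(minutes=2),
    interval=5.0,
    timeout=600.0,
)
wg.submit()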
33 changes: 20 additions & 13 deletions aiida_workgraph/utils/control.py
@@ -109,24 +109,31 @@ def kill_tasks(pk: int, tasks: list, timeout: int = 5, wait: bool = False):
"WAITING",
"PAUSED",
]:
print("tasks", tasks)
for name in tasks:
if get_task_state_info(node, name, "state") == "PLANNED":
state = get_task_state_info(node, name, "state")
process = get_task_state_info(node, name, "process")
print("state", state)
print("process", process)
if state == "PLANNED":
create_task_action(pk, tasks, action="skip")
elif get_task_state_info(node, name, "state") in [
elif state in [
"CREATED",
"RUNNING",
"WAITING",
"PAUSED",
]:
try:
control.kill_processes(
[get_task_state_info(node, name, "process")],
all_entries=None,
timeout=5,
wait=False,
)
except Exception as e:
print(f"Kill task {name} failed: {e}")
elif get_task_state_info(node, name, "process").is_finished:
raise ValueError(f"Task {name} is already finished.")
if get_task_state_info(node, name, "process") is None:
print(f"Task {name} is not a AiiDA process.")
create_task_action(pk, tasks, action="kill")
else:
try:
control.kill_processes(
[get_task_state_info(node, name, "process")],
all_entries=None,
timeout=5,
wait=False,
)
except Exception as e:
print(f"Kill task {name} failed: {e}")
return True, ""
12 changes: 12 additions & 0 deletions aiida_workgraph/workgraph.py
@@ -395,6 +395,18 @@ def play_tasks(self, tasks: List[str]) -> None:
_, msg = play_tasks(self.process.pk, tasks)
return "Send message to play tasks."

def kill_tasks(self, tasks: List[str]) -> None:
"""Kill the given tasks"""

from aiida_workgraph.utils.control import kill_tasks

if self.process is None:
for name in tasks:
self.tasks[name].action = "KILL"
else:
_, msg = kill_tasks(self.process.pk, tasks)
return "Send message to kill tasks."

def continue_process(self):
"""Continue a saved process by sending the task to RabbitMA.
Use with caution, this may launch duplicate processes."""
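At the Python API level, the new `WorkGraph.kill_tasks` method mirrors the existing pause/play helpers: before submission it only records a "KILL" action on the named tasks, while on a running process it sends the kill message. Continuing the earlier sketch (a submitted workgraph with a task named "monitor1", a placeholder name):

# wg is the already-submitted WorkGraph from the sketch above.
wg.kill_tasks(["monitor1"])  # sends the "kill" task action to the running process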