Skip to content

Commit

Permalink
Add If task (#222)
Browse files Browse the repository at this point in the history
The `If Task` is visually represented as an "If Zone." This zone encapsulates all its child tasks, which are executed based on the defined conditions.

- **Conditions**: The `If` Task includes a `conditions` socket, which determines when the tasks inside the zone should be executed.
- **Invert_condition**: If this input is True, it will invert the conditions.
- **Task Linking**: Tasks located outside the If Zone can be directly linked to tasks within the zone, allowing for dynamic workflow adjustments based on conditional outcomes.
  • Loading branch information
superstar54 committed Aug 13, 2024
1 parent 0692239 commit f003c34
Show file tree
Hide file tree
Showing 14 changed files with 2,776 additions and 1,437 deletions.
3 changes: 3 additions & 0 deletions aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def new(
if isinstance(identifier, str) and identifier.upper() == "WHILE":
task = super().new("workgraph.while", name, uuid, **kwargs)
return task
if isinstance(identifier, str) and identifier.upper() == "IF":
task = super().new("workgraph.if", name, uuid, **kwargs)
return task
if isinstance(identifier, WorkGraph):
identifier = build_task_from_workgraph(identifier)
return super().new(identifier, name, uuid, **kwargs)
Expand Down
45 changes: 34 additions & 11 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,15 +654,15 @@ def update_parent_task_state(self, name: str) -> None:
"""Update parent task state."""
parent_task = self.ctx.tasks[name]["parent_task"]
if parent_task[0]:
if (
self.ctx.tasks[parent_task[0]]["metadata"]["node_type"].upper()
== "WHILE"
):
task_type = self.ctx.tasks[parent_task[0]]["metadata"]["node_type"].upper()
if task_type == "WHILE":
self.update_while_task_state(parent_task[0])
elif task_type == "IF":
self.update_if_task_state(parent_task[0])

def update_while_task_state(self, name: str) -> None:
"""Update while task state."""
finished, _ = self.is_while_task_finished(name)
finished, _ = self.are_childen_finished(name)

if finished:
should_run = self.should_run_while_task(name)
Expand All @@ -672,7 +672,14 @@ def update_while_task_state(self, name: str) -> None:
self.reset_task(name, reset_execution_count=False)
else:
self.set_task_state_info(name, "state", "FINISHED")
self.update_parent_task_state(name)
self.update_parent_task_state(name)

def update_if_task_state(self, name: str) -> None:
"""Update if task state."""
finished, _ = self.are_childen_finished(name)
if finished:
self.set_task_state_info(name, "state", "FINISHED")
self.update_parent_task_state(name)

def should_run_while_task(self, name: str) -> tuple[bool, t.Any]:
"""Check if the while task should run."""
Expand All @@ -688,8 +695,16 @@ def should_run_while_task(self, name: str) -> tuple[bool, t.Any]:
conditions.append(value)
return False not in conditions

def is_while_task_finished(self, name: str) -> tuple[bool, t.Any]:
"""Check if the while task is finished."""
def should_run_if_task(self, name: str) -> tuple[bool, t.Any]:
"""Check if the IF task should run."""
_, kwargs, _, _, _ = self.get_inputs(name)
flag = kwargs["conditions"]
if kwargs["invert_condition"]:
return not flag
return flag

def are_childen_finished(self, name: str) -> tuple[bool, t.Any]:
"""Check if the child tasks are finished."""
task = self.ctx.tasks[name]
finished = True
for name in task["children"]:
Expand Down Expand Up @@ -874,7 +889,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
# print("executor: ", task["executor"])
executor, _ = get_executor(task["executor"])
# print("executor: ", executor)
args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(task)
args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name)
for i, key in enumerate(self.ctx.tasks[name]["metadata"]["args"]):
kwargs[key] = args[i]
# update the port namespace
Expand Down Expand Up @@ -1030,10 +1045,17 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
task["execution_count"] += 1
self.set_task_state_info(name, "state", "RUNNING")
else:
# if the first run is skipped, we set the child tasks to SKIPPED
self.set_tasks_state(task["children"], "FINISHED")
self.update_while_task_state(name)
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["IF"]:
should_run = self.should_run_if_task(name)
if should_run:
self.set_task_state_info(name, "state", "RUNNING")
else:
self.set_tasks_state(task["children"], "SKIPPED")
self.update_if_task_state(name)
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["NORMAL"]:
print("Task type: Normal.")
# normal function does not have a process
Expand Down Expand Up @@ -1072,7 +1094,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
return self.exit_codes.UNKNOWN_TASK_TYPE

def get_inputs(
self, task: t.Dict[str, t.Any]
self, name: str
) -> t.Tuple[
t.List[t.Any],
t.Dict[str, t.Any],
Expand All @@ -1087,6 +1109,7 @@ def get_inputs(
kwargs = {}
var_args = None
var_kwargs = None
task = self.ctx.tasks[name]
properties = task.get("properties", {})
# TODO: check if input is linked, otherwise use the property value
inputs = {}
Expand Down
19 changes: 19 additions & 0 deletions aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.any", "_wait")


class If(Task):
"""If task"""

identifier = "workgraph.if"
name = "If"
node_type = "IF"
catalog = "Control"
kwargs = ["conditions", "invert_condition"]

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
inp = self.inputs.new("workgraph.any", "_wait")
inp.link_limit = 100000
self.inputs.new("workgraph.any", "conditions")
self.inputs.new("workgraph.any", "invert_condition")
self.outputs.new("workgraph.any", "_wait")


class Gather(Task):
"""Gather"""

Expand Down
2 changes: 1 addition & 1 deletion aiida_workgraph/web/frontend/src/rete/default.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ export async function loadJSON(editor: NodeEditor<any>, area: any, workgraphData
const nodeData = workgraphData.nodes[nodeId];
// if node_type is "WHILE", find all
console.log("Node type: ", nodeData['node_type']);
if (nodeData['node_type'] === "WHILE") {
if (nodeData['node_type'] === "WHILE" || nodeData['node_type'] === "IF") {
// find the node
const node = nodeMap[nodeData.label];
const children = nodeData['children'];
Expand Down
8 changes: 4 additions & 4 deletions aiida_workgraph/widget/js/default_rete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ export async function loadJSON(editor, area, layout, workgraphData) {
for (const nodeId in workgraphData.nodes) {
const nodeData = workgraphData.nodes[nodeId];
// if node_type is "WHILE", find all
if (nodeData['node_type'] === "WHILE") {
if (nodeData['node_type'] === "WHILE" || nodeData['node_type'] === "IF") {
// find the node
const node = editor.nodeMap[nodeData.label];
const tasks = nodeData['properties']['tasks']['value'];
const children = nodeData['children'];
// find the id of all nodes in the editor that has a label in while_zone
for (const nodeId in tasks) {
const node1 = editor.nodeMap[tasks[nodeId]];
for (const nodeId in children) {
const node1 = editor.nodeMap[children[nodeId]];
node1.parent = node.id;
}
}
Expand Down
8 changes: 4 additions & 4 deletions aiida_workgraph/widget/src/widget/html_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ class Connection extends ClassicPreset.Connection {}
for (const nodeId in workgraphData.nodes) {
const nodeData = workgraphData.nodes[nodeId];
// if node_type is "WHILE", find all
if (nodeData['node_type'] === "WHILE") {
if (nodeData['node_type'] === "WHILE" || nodeData['node_type'] === "IF") {
// find the node
const node = editor.nodeMap[nodeData.label];
const tasks = nodeData['properties']['tasks']['value'];
const children = nodeData['children'];
// find the id of all nodes in the editor that has a label in while_zone
for (const nodeId in tasks) {
const node1 = editor.nodeMap[tasks[nodeId]];
for (const nodeId in children) {
const node1 = editor.nodeMap[children[nodeId]];
node1.parent = node.id;
}
}
Expand Down
Loading

0 comments on commit f003c34

Please sign in to comment.