Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TaskCollection for waiting_on and children. #223

Merged
merged 5 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 76 additions & 12 deletions aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
WorkGraphOutputSocketCollection,
)
import aiida
from typing import Any, Dict, Optional, Union, Callable, List
from typing import Any, Dict, Optional, Union, Callable, List, Set, Iterable
from aiida_workgraph.utils.message import WIDGET_INSTALLATION_MESSAGE


Expand All @@ -29,8 +29,6 @@ class Task(GraphNode):
def __init__(
self,
context_mapping: Optional[List[Any]] = None,
wait: List[Union[str, GraphNode]] = None,
children: List[Union[str, GraphNode]] = None,
process: Optional[aiida.orm.ProcessNode] = None,
pk: Optional[int] = None,
**kwargs: Any,
Expand All @@ -45,8 +43,7 @@ def __init__(
**kwargs,
)
self.context_mapping = {} if context_mapping is None else context_mapping
self.wait = [] if wait is None else wait
self.children = [] if children is None else children
self.waiting_on = TaskCollection(parent=self)
self.process = process
self.pk = pk
if USE_WIDGET:
Expand All @@ -64,12 +61,8 @@ def to_dict(self) -> Dict[str, Any]:

tdata = super().to_dict()
tdata["context_mapping"] = self.context_mapping
tdata["wait"] = [
task if isinstance(task, str) else task.name for task in self.wait
]
tdata["children"] = [
task if isinstance(task, str) else task.name for task in self.children
]
tdata["wait"] = [task.name for task in self.waiting_on]
tdata["children"] = []
tdata["execution_count"] = 0
tdata["parent_task"] = [None]
tdata["process"] = serialize(self.process) if self.process else serialize(None)
Expand Down Expand Up @@ -122,7 +115,7 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta

task = super().from_dict(data, node_pool=task_pool)
task.context_mapping = data.get("context_mapping", {})
task.wait = data.get("wait", [])
task.waiting_on.add(data.get("wait", []))
task.process = data.get("process", None)

return task
Expand Down Expand Up @@ -150,3 +143,74 @@ def to_html(self, output: str = None, **kwargs):
return
self._widget.from_node(self)
return self._widget.to_html(output=output, **kwargs)


class TaskCollection:
def __init__(self, parent: "Task"):
self._items: Set[str] = set()
self.parent = parent
self._top_parent = None

@property
def graph(self) -> "WorkGraph":
"""Cache and return the top parent of the collection."""
if not self._top_parent:
parent = self.parent
while getattr(parent, "parent", None):
parent = parent.parent
self._top_parent = parent
return self._top_parent

@property
def items(self) -> Set[str]:
return self._items

def _normalize_tasks(
self, tasks: Union[List[Union[str, Task]], str, Task]
) -> Iterable[str]:
"""Normalize input to an iterable of task names."""
if isinstance(tasks, (str, Task)):
tasks = [tasks]
task_objects = []
for task in tasks:
if isinstance(task, str):
if task not in self.graph.tasks.keys():
raise ValueError(
f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks.keys()}"
)
task_objects.append(self.graph.tasks[task])
elif isinstance(task, Task):
task_objects.append(task)
else:
raise ValueError(f"Invalid task type: {type(task)}")
return task_objects

def add(self, tasks: Union[List[Union[str, Task]], str, Task]) -> None:
"""Add tasks to the collection. Tasks can be a list or a single Task or task name."""
for task in self._normalize_tasks(tasks):
self._items.add(task)

def remove(self, tasks: Union[List[Union[str, Task]], str, Task]) -> None:
"""Remove tasks from the collection. Tasks can be a list or a single Task or task name."""
for task in self._normalize_tasks(tasks):
if task not in self._items:
raise ValueError(f"Task '{task.name}' is not in the collection.")
self._items.remove(task)

def clear(self) -> None:
"""Clear all items from the collection."""
self._items.clear()

def __contains__(self, item: str) -> bool:
"""Check if a task name is in the collection."""
return item in self._items

def __iter__(self):
"""Yield each task name in the collection for iteration."""
return iter(self._items)

def __len__(self) -> int:
return len(self._items)

def __repr__(self) -> str:
return f"{self._items}"
25 changes: 20 additions & 5 deletions aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
from typing import Dict
from aiida_workgraph.task import Task
from typing import Any, Dict
from aiida_workgraph.task import Task, TaskCollection


class Zone(Task):
"""Zone"""
"""
Extend the Task class to include a 'children' attribute.
"""

identifier = "workgraph.zone"
name = "Zone"
node_type = "ZONE"
catalog = "Control"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.children = TaskCollection(parent=self)

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
inp = self.inputs.new("workgraph.any", "_wait")
inp.link_limit = 100000
self.outputs.new("workgraph.any", "_wait")

def to_dict(self) -> Dict[str, Any]:
tdata = super().to_dict()
tdata["children"] = [task.name for task in self.children]
return tdata

def from_dict(self, data: Dict[str, Any]) -> None:
super().from_dict(data)
self.children.add(data.get("children", []))


class While(Task):
class While(Zone):
"""While"""

identifier = "workgraph.while"
Expand All @@ -38,7 +53,7 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.any", "_wait")


class If(Task):
class If(Zone):
"""If task"""

identifier = "workgraph.if"
Expand Down
14 changes: 7 additions & 7 deletions aiida_workgraph/web/frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,6 @@ def extend(self, wg: "WorkGraph", prefix: str = "") -> None:
"""
for task in wg.tasks:
task.name = prefix + task.name
task.wait = [prefix + w for w in task.wait] if task.wait else []
task.parent = self
self.tasks.append(task)
# self.sequence.extend([prefix + task for task in wg.sequence])
Expand Down
2 changes: 1 addition & 1 deletion docs/source/howto/html/test_zone.html
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
const { RenderUtils } = ReteRenderUtils;
const styled = window.styled;

const workgraphData = {"name": "test_zone", "uuid": "a9c614a4-5a09-11ef-aedd-906584de3e5b", "state": "CREATED", "nodes": {"add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "a9d08812-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9d08100-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "a9d088d0-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9d08100-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add2": {"label": "add2", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "a9d9fe74-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9d9f8a2-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "a9d9fec4-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9d9f8a2-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [], "position": [60, 60], "children": []}, "add3": {"label": "add3", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "a9e2d328-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9e2cd38-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "a9e2d38c-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9e2cd38-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "a9d089f2-5a09-11ef-aedd-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y"}], "outputs": [{"name": "result"}, {"name": "result"}], "position": [90, 90], "children": []}, "add4": {"label": "add4", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "a9eb9f8a-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9eb9404-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "a9eb9fe4-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9eb9404-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add3", "from_socket": "result", "from_socket_uuid": "a9e2d45e-5a09-11ef-aedd-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y"}], "outputs": [], "position": [120, 120], "children": []}, "add5": {"label": "add5", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "a9f6b820-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9f6b01e-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "a9f6b8ac-5a09-11ef-aedd-906584de3e5b", "node_uuid": "a9f6b01e-5a09-11ef-aedd-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add3", "from_socket": "result", "from_socket_uuid": "a9e2d45e-5a09-11ef-aedd-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y"}], "outputs": [], "position": [150, 150], "children": []}, "Zone1": {"label": "Zone1", "node_type": "ZONE", "inputs": [], "outputs": [], "position": [180, 180], "children": ["add2", "add3", "add4"]}}, "links": [{"from_socket": "result", "from_node": "add1", "from_socket_uuid": "a9d089f2-5a09-11ef-aedd-906584de3e5b", "to_socket": "y", "to_node": "add3", "state": false}, {"from_socket": "result", "from_node": "add3", "from_socket_uuid": "a9e2d45e-5a09-11ef-aedd-906584de3e5b", "to_socket": "y", "to_node": "add4", "state": false}, {"from_socket": "result", "from_node": "add3", "from_socket_uuid": "a9e2d45e-5a09-11ef-aedd-906584de3e5b", "to_socket": "y", "to_node": "add5", "state": false}]}
const workgraphData = {"name": "test_zone", "uuid": "55e69bb8-5a0b-11ef-8005-906584de3e5b", "state": "CREATED", "nodes": {"add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "55f0e046-5a0b-11ef-8005-906584de3e5b", "node_uuid": "55f0d97a-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "55f0e0fa-5a0b-11ef-8005-906584de3e5b", "node_uuid": "55f0d97a-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add2": {"label": "add2", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "55fa09c8-5a0b-11ef-8005-906584de3e5b", "node_uuid": "55fa0428-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "55fa0a18-5a0b-11ef-8005-906584de3e5b", "node_uuid": "55fa0428-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [], "position": [60, 60], "children": []}, "add3": {"label": "add3", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "5602089e-5a0b-11ef-8005-906584de3e5b", "node_uuid": "5602031c-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "560208ee-5a0b-11ef-8005-906584de3e5b", "node_uuid": "5602031c-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "55f0e212-5a0b-11ef-8005-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y"}], "outputs": [{"name": "result"}, {"name": "result"}], "position": [90, 90], "children": []}, "add4": {"label": "add4", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "560a9a18-5a0b-11ef-8005-906584de3e5b", "node_uuid": "560a94b4-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "560a9a72-5a0b-11ef-8005-906584de3e5b", "node_uuid": "560a94b4-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add3", "from_socket": "result", "from_socket_uuid": "560209ac-5a0b-11ef-8005-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y"}], "outputs": [], "position": [120, 120], "children": []}, "add5": {"label": "add5", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "561355b8-5a0b-11ef-8005-906584de3e5b", "node_uuid": "56134fc8-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "56135608-5a0b-11ef-8005-906584de3e5b", "node_uuid": "56134fc8-5a0b-11ef-8005-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add3", "from_socket": "result", "from_socket_uuid": "560209ac-5a0b-11ef-8005-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y"}], "outputs": [], "position": [150, 150], "children": []}, "Zone1": {"label": "Zone1", "node_type": "ZONE", "inputs": [], "outputs": [], "position": [180, 180], "children": ["add3", "add2", "add4"]}}, "links": [{"from_socket": "result", "from_node": "add1", "from_socket_uuid": "55f0e212-5a0b-11ef-8005-906584de3e5b", "to_socket": "y", "to_node": "add3", "state": false}, {"from_socket": "result", "from_node": "add3", "from_socket_uuid": "560209ac-5a0b-11ef-8005-906584de3e5b", "to_socket": "y", "to_node": "add4", "state": false}, {"from_socket": "result", "from_node": "add3", "from_socket_uuid": "560209ac-5a0b-11ef-8005-906584de3e5b", "to_socket": "y", "to_node": "add5", "state": false}]}

// Define Schemes to use in vanilla JS
const Schemes = {
Expand Down
Loading
Loading