Skip to content

Commit

Permalink
if user already defined the var_args in the inputs, skip adding it (#264
Browse files Browse the repository at this point in the history
)
  • Loading branch information
superstar54 committed Aug 23, 2024
1 parent fc2cd43 commit db2ddee
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
16 changes: 9 additions & 7 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,15 @@ def build_task_from_AiiDA(
or executor.process_class._var_positional
)
tdata["var_kwargs"] = name
inputs.append(
{
"identifier": "workgraph.any",
"name": name,
"property": {"identifier": "workgraph.any", "default": {}},
}
)
# if user already defined the var_args in the inputs, skip it
if name not in [input["name"] for input in inputs]:
inputs.append(
{
"identifier": "workgraph.any",
"name": name,
"property": {"identifier": "workgraph.any", "default": {}},
}
)
# TODO In order to reload the WorkGraph from process, "is_pickle" should be True
# so I pickled the function here, but this is not necessary
# we need to update the node_graph to support the path and name of the function
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"numpy~=1.21",
"scipy",
"ase",
"node-graph>=0.0.13",
"node-graph>=0.0.14",
"aiida-core>=2.3",
"cloudpickle",
"aiida-shell",
Expand Down
19 changes: 17 additions & 2 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from aiida_workgraph import WorkGraph
from aiida_workgraph import WorkGraph, task
from typing import Callable
from aiida_workgraph import task


@pytest.fixture(params=["decorator_factory", "decorator"])
Expand Down Expand Up @@ -117,6 +116,22 @@ def test_decorators_workfunction_args(task_workfunction) -> None:
assert n.outputs.keys() == ["result", "_outputs", "_wait"]


def test_decorators_parameters() -> None:
"""Test passing parameters to decorators."""

@task.calcfunction(
inputs=[{"name": "c", "link_limit": 1000}],
outputs=[{"name": "sum"}, {"name": "product"}],
)
def test(a, b=1, **c):
return {"sum": a + b, "product": a * b}

test1 = test.task()
assert test1.inputs["c"].link_limit == 1000
assert "sum" in test1.outputs.keys()
assert "product" in test1.outputs.keys()


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_graph_builder(request):
if request.param == "decorator_factory":
Expand Down

0 comments on commit db2ddee

Please sign in to comment.