Skip to content

Commit

Permalink
add worktree.save
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 4, 2023
1 parent 4e835e9 commit cc57989
Show file tree
Hide file tree
Showing 7 changed files with 557 additions and 50 deletions.
49 changes: 23 additions & 26 deletions aiida_worktree/engine/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
def define(cls, spec):
super().define(spec)
spec.input("input_file", valid_type=orm.SinglefileData, required=False)
spec.input_namespace("nt", dynamic=True, required=False)
spec.input_namespace("worktree", dynamic=True, required=False)
spec.input_namespace("input_nodes", dynamic=True, required=False)
spec.exit_code(2, "ERROR_SUBPROCESS", message="A subprocess has failed.")

Expand Down Expand Up @@ -388,28 +388,18 @@ def setup(self):
self.ctx._awaitable_actions = []
self.ctx.new_data = dict()
self.ctx.input_nodes = dict()
if "input_file" in self.inputs:
ext = splitext(self.inputs["input_file"].filename)[1]
with self.inputs["input_file"].open(mode="r") as f:
if ext == ".yaml":
ntdata = yaml.safe_load(f)
else:
raise Exception("Please use a yaml file.")
elif "nt" in self.inputs:
ntdata = self.inputs["nt"]
else:
raise Exception("Please set input!")

# ntdata = jsonref.JsonRef.replace_refs(tntdata, loader = JsonYamlLoader())
build_node_link(ntdata)
self.init_ctx(ntdata["ctx"])
self.ctx.nodes = ntdata["nodes"]
self.ctx.links = ntdata["links"]
self.ctx.ctrl_links = ntdata["ctrl_links"]
self.ctx.worktree = ntdata
# read the latest worktree data
wtdata = self.read_wtdata_from_base()
#
build_node_link(wtdata)
self.init_ctx(wtdata["ctx"])
self.ctx.nodes = wtdata["nodes"]
self.ctx.links = wtdata["links"]
self.ctx.ctrl_links = wtdata["ctrl_links"]
self.ctx.worktree = wtdata
print("init")
#
nc = ConnectivityAnalysis(ntdata)
nc = ConnectivityAnalysis(wtdata)
self.ctx.connectivity = nc.build_connectivity()
self.ctx.msgs = []
self.node.set_process_label(f"WorkTree: {self.ctx.worktree['name']}")
Expand All @@ -426,6 +416,13 @@ def setup(self):
# init node results
self.set_node_results()

def read_wtdata_from_base(self):
"""Read worktree data from base.extras."""
from aiida.orm.utils.serialize import deserialize_unsafe

wtdata = deserialize_unsafe(self.node.base.extras.get("worktree"))
return wtdata

def init_ctx(self, datas):
from aiida_worktree.utils import update_nested_dict

Expand Down Expand Up @@ -654,14 +651,14 @@ def run_nodes(self, names):
print("group outputs: ", executor.group_outputs)
wt.group_outputs = executor.group_outputs
wt.name = name
ntdata = wt.to_dict()
wtdata = wt.to_dict()
# merge the kwargs
merge_properties(ntdata)
all = {"nt": ntdata, "metadata": {"call_link_label": name}}
merge_properties(wtdata)
all = {"worktree": wtdata, "metadata": {"call_link_label": name}}
print("submit worktree: ")
process = self.submit(self.__class__, **all)
# save the ntdata to the process extras, so that we can load the worktree
process.base.extras.set("nt", serialize(ntdata))
# save the wtdata to the process extras, so that we can load the worktree
process.base.extras.set("worktree", serialize(wtdata))
node["process"] = process
# self.ctx.nodes[name]["group_outputs"] = executor.group_outputs
self.ctx.nodes[name]["state"] = "RUNNING"
Expand Down
File renamed without changes.
Empty file added aiida_worktree/utils/tree.py
Empty file.
94 changes: 71 additions & 23 deletions aiida_worktree/worktree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import node_graph
import aiida
from aiida_worktree.nodes import node_pool
import time


class WorkTree(node_graph.NodeGraph):
Expand Down Expand Up @@ -32,22 +33,35 @@ def __init__(self, name="WorkTree", **kwargs):
self.worktree_type = "NORMAL"
self.sequence = []
self.conditions = []
self.process = None

def run(self):
"""
Run the AiiDA worktree process and update the process status. The method uses AiiDA's engine to run
the process and then calls the update method to update the state of the process.
"""
from aiida_worktree.engine.worktree import WorkTree
from aiida_worktree.engine.worktree import WorkTree as WorkTreeEngine
from aiida_worktree.utils import merge_properties
from aiida.orm.utils.serialize import serialize

ntdata = self.to_dict()
merge_properties(ntdata)
all = {"nt": ntdata}
_result, self.process = aiida.engine.run_get_node(WorkTree, **all)
self.process.base.extras.set("nt", serialize(ntdata))
from aiida.manage import manager

# One can not run again if the process is alreay created. otherwise, a new process node will
# be created again.
if self.process is not None:
print("Your worktree is already created. Please use the submit() method.")
return
wtdata = self.to_dict()
merge_properties(wtdata)
inputs = {"worktree": wtdata}
# init a process
runner = manager.get_manager().get_runner()
process_inited = WorkTreeEngine(runner=runner, inputs=inputs)
self.process = process_inited.node
# save worktree data into process node
self.process.base.extras.set("worktree", serialize(wtdata))
result = aiida.engine.run(process_inited)
self.update()
return result

def submit(self, wait=False, timeout=60):
"""
Expand All @@ -57,30 +71,65 @@ def submit(self, wait=False, timeout=60):
wait (bool, optional): If True, the function will wait until the process finishes. Defaults to False.
timeout (int, optional): The maximum time in seconds to wait for the process to finish. Defaults to 60.
"""
from aiida_worktree.engine.worktree import WorkTree
from aiida_worktree.utils import merge_properties
from aiida.orm.utils.serialize import serialize
from aiida.engine.processes import control

ntdata = self.to_dict()
merge_properties(ntdata)
all = {"nt": ntdata}
self.process = aiida.engine.submit(WorkTree, **all)
#
self.process.base.extras.set("nt", serialize(ntdata))
self.save()
if self.process.process_state.value.upper() not in ["CREATED"]:
return "Error!!! The process has already been submitted and finished."
# TODO in case of "[ERROR] Process<3705> is unreachable."
control.play_processes([self.process])
if wait:
self.wait(timeout=timeout)

def save(self):
"""Save the udpated worktree to the process
This is only used for a running worktree.
Save the AiiDA worktree process and update the process status.
"""
from aiida_worktree.engine.worktree import WorkTree as WorkTreeEngine
from aiida_worktree.utils import merge_properties
from aiida.manage import manager

wtdata = self.to_dict()
merge_properties(wtdata)
inputs = {"worktree": wtdata}
runner = manager.get_manager().get_runner()
if self.process is None:
# init a process node
process_inited = WorkTreeEngine(runner=runner, inputs=inputs)
runner.persister.save_checkpoint(process_inited)
# return the future result
# future = runner.controller.continue_process(process_inited.pid)
self.process = process_inited.node
# start = time.time()
# while not future.done():
# time.sleep(1)
# if time.time() - start > 5:
# print("Worktree dosen't save properly.")
# return
# print(f"WorkTree node crated, PK: {self.process.pk}")
self.save_to_base(wtdata)
self.update()

def save_to_base(self, wtdata):
"""Save new wtdata to base.extras.
It will first check the difference, and reset nodes if needed.
"""
from aiida.orm.utils.serialize import serialize

self.process.base.extras.set("worktree", serialize(wtdata))

def to_dict(self):
ntdata = super().to_dict()
wtdata = super().to_dict()
self.ctx["sequence"] = self.sequence
# only alphanumeric and underscores are allowed
ntdata["ctx"] = {
wtdata["ctx"] = {
key.replace(".", "__"): value for key, value in self.ctx.items()
}
ntdata["worktree_type"] = self.worktree_type
ntdata["conditions"] = self.conditions
wtdata["worktree_type"] = self.worktree_type
wtdata["conditions"] = self.conditions

return ntdata
return wtdata

def wait(self, timeout=50):
"""
Expand All @@ -89,7 +138,6 @@ def wait(self, timeout=50):
Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
"""
import time

start = time.time()
self.update()
Expand Down Expand Up @@ -144,7 +192,7 @@ def load(cls, pk):
from aiida.orm.utils.serialize import deserialize_unsafe

process = aiida.orm.load_node(pk)
wtdata = deserialize_unsafe(process.base.extras.get("nt"))
wtdata = deserialize_unsafe(process.base.extras.get("worktree"))
wt = cls.from_dict(wtdata)
wt.process = process
wt.update()
Expand Down
1 change: 1 addition & 0 deletions docs/source/howto/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This section contains a collection of HowTos for various topics.
while
ctx
wait
restart
continue_finished_worktree
protocol
cli
Loading

0 comments on commit cc57989

Please sign in to comment.