Skip to content

Commit

Permalink
add get_builder_from_protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Oct 12, 2023
1 parent 7765109 commit 8f820e7
Showing 1 changed file with 56 additions and 8 deletions.
64 changes: 56 additions & 8 deletions aiida_bader/workchains/qe_bader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType
from aiida import orm

PwBaseWorkChain = WorkflowFactory(
"quantumespresso.pw.base"
Expand All @@ -23,15 +24,15 @@ def define(cls, spec):
"""Define workflow specification."""
super(QeBaderWorkChain, cls).define(spec)

spec.expose_inputs(PwBaseWorkChain, namespace="pw_base")
spec.expose_inputs(PwBaseWorkChain, namespace="scf")
spec.expose_inputs(PpCalculation, namespace="pp", exclude=["parent_folder"])
spec.expose_inputs(
BaderCalculation, namespace="bader", exclude=["charge_density_folder"]
)

spec.outline(cls.run_pw, cls.run_pp, cls.run_bader, cls.return_results)

spec.expose_outputs(PwBaseWorkChain, namespace="pw_base")
spec.expose_outputs(PwBaseWorkChain, namespace="scf")
spec.expose_outputs(PpCalculation, namespace="pp")
spec.expose_outputs(BaderCalculation, namespace="bader")

Expand All @@ -41,12 +42,59 @@ def define(cls, spec):
905, "ERROR_PARSING_BADER_OUTPUT", "Error while parsing bader output"
)

@classmethod
def get_builder_from_protocol(
cls,
pw_code,
pp_code,
bader_code,
structure,
protocol=None,
overrides=None,
options=None,
**kwargs
):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
:param code: the ``Code`` instance configured for the ``quantumespresso.pw`` plugin.
:param structure: the ``StructureData`` instance to use.
:param protocol: protocol to use, if not specified, the default will be used.
:param overrides: optional dictionary of inputs to override the defaults of the protocol.
"""

if isinstance(pw_code, str):
pw_code = orm.load_code(pw_code)
if isinstance(bader_code, str):
bader_code = orm.load_code(bader_code)

inputs = cls.get_protocol_inputs(protocol, overrides)

scf = PwBaseWorkChain.get_builder_from_protocol(
pw_code, structure, protocol, overrides=inputs.get('scf', None),
options=options, **kwargs
)
pp = PpCalculation.get_builder_from_protocol(
pp_code, structure, protocol, overrides=inputs.get('pp', None),
options=options, **kwargs
)
bader = BaderCalculation.get_builder_from_protocol(
bader_code, structure, protocol, overrides=inputs.get('bader', None),
options=options, **kwargs
)

builder = cls.get_builder()
builder.scf = scf
builder.pp = pp
builder.bader = bader

return builder

def run_pw(self):
"""Run PW."""
pw_base_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, "pw_base"))
pw_base_inputs["metadata"]["label"] = "pw_scf"
pw_base_inputs["metadata"]["call_link_label"] = "call_pw_scf"
running = self.submit(PwBaseWorkChain, **pw_base_inputs)
scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, "scf"))
scf_inputs["metadata"]["label"] = "pw_scf"
scf_inputs["metadata"]["call_link_label"] = "call_pw_scf"
running = self.submit(PwBaseWorkChain, **scf_inputs)
self.report("Running PwBaseWorkChain.")
return ToContext(pw_calc=running)

Expand Down Expand Up @@ -95,7 +143,7 @@ def return_results(self):
try:
self.out_many(
self.exposed_outputs(
self.ctx.pw_calc, PwBaseWorkChain, namespace="pw_base"
self.ctx.pw_calc, PwBaseWorkChain, namespace="scf"
)
)
self.out_many(
Expand Down

0 comments on commit 8f820e7

Please sign in to comment.