diff --git a/aiida_bader/workchains/qe_bader.py b/aiida_bader/workchains/qe_bader.py index 1b088a5..4c4780a 100644 --- a/aiida_bader/workchains/qe_bader.py +++ b/aiida_bader/workchains/qe_bader.py @@ -6,7 +6,8 @@ from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain from aiida.plugins import CalculationFactory, WorkflowFactory - +from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType +from aiida import orm PwBaseWorkChain = WorkflowFactory( "quantumespresso.pw.base" @@ -23,7 +24,7 @@ def define(cls, spec): """Define workflow specification.""" super(QeBaderWorkChain, cls).define(spec) - spec.expose_inputs(PwBaseWorkChain, namespace="pw_base") + spec.expose_inputs(PwBaseWorkChain, namespace="scf") spec.expose_inputs(PpCalculation, namespace="pp", exclude=["parent_folder"]) spec.expose_inputs( BaderCalculation, namespace="bader", exclude=["charge_density_folder"] @@ -31,7 +32,7 @@ def define(cls, spec): spec.outline(cls.run_pw, cls.run_pp, cls.run_bader, cls.return_results) - spec.expose_outputs(PwBaseWorkChain, namespace="pw_base") + spec.expose_outputs(PwBaseWorkChain, namespace="scf") spec.expose_outputs(PpCalculation, namespace="pp") spec.expose_outputs(BaderCalculation, namespace="bader") @@ -41,12 +42,59 @@ def define(cls, spec): 905, "ERROR_PARSING_BADER_OUTPUT", "Error while parsing bader output" ) + @classmethod + def get_builder_from_protocol( + cls, + pw_code, + pp_code, + bader_code, + structure, + protocol=None, + overrides=None, + options=None, + **kwargs + ): + """Return a builder prepopulated with inputs selected according to the chosen protocol. + + :param code: the ``Code`` instance configured for the ``quantumespresso.pw`` plugin. + :param structure: the ``StructureData`` instance to use. + :param protocol: protocol to use, if not specified, the default will be used. + :param overrides: optional dictionary of inputs to override the defaults of the protocol. + """ + + if isinstance(pw_code, str): + pw_code = orm.load_code(pw_code) + if isinstance(bader_code, str): + bader_code = orm.load_code(bader_code) + + inputs = cls.get_protocol_inputs(protocol, overrides) + + scf = PwBaseWorkChain.get_builder_from_protocol( + pw_code, structure, protocol, overrides=inputs.get('scf', None), + options=options, **kwargs + ) + pp = PpCalculation.get_builder_from_protocol( + pp_code, structure, protocol, overrides=inputs.get('pp', None), + options=options, **kwargs + ) + bader = BaderCalculation.get_builder_from_protocol( + bader_code, structure, protocol, overrides=inputs.get('bader', None), + options=options, **kwargs + ) + + builder = cls.get_builder() + builder.scf = scf + builder.pp = pp + builder.bader = bader + + return builder + def run_pw(self): """Run PW.""" - pw_base_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, "pw_base")) - pw_base_inputs["metadata"]["label"] = "pw_scf" - pw_base_inputs["metadata"]["call_link_label"] = "call_pw_scf" - running = self.submit(PwBaseWorkChain, **pw_base_inputs) + scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, "scf")) + scf_inputs["metadata"]["label"] = "pw_scf" + scf_inputs["metadata"]["call_link_label"] = "call_pw_scf" + running = self.submit(PwBaseWorkChain, **scf_inputs) self.report("Running PwBaseWorkChain.") return ToContext(pw_calc=running) @@ -95,7 +143,7 @@ def return_results(self): try: self.out_many( self.exposed_outputs( - self.ctx.pw_calc, PwBaseWorkChain, namespace="pw_base" + self.ctx.pw_calc, PwBaseWorkChain, namespace="scf" ) ) self.out_many(