Skip to content

Commit

Permalink
Store only needed one band bxsfs
Browse files Browse the repository at this point in the history
  * In wan2skeaf.jl write bxsfs only for bands that cross Fermi energy (FE)
  * The workflow stores only the needed bxsfs as output RemoteData nodes
  * Allow passing a custom Fermi energy to wan2skeaf.jl
  * New parameter `fermi_energy` to workchain.inputs.wan2skeaf
  * Modify `validate_inputs` function of the SkeafWorkChain to check that `fermi_energy` is passed correctly
      - if FE is specified only in skeaf parameters, it is passed also to wan2skeaf parameters
      - error if FE specified only in wan2skeaf or values in skeaf and wan2skeaf parameters are different
  * New entries `custom_fermi_energy` and `bands_crossing_fermi` in wan2skeaf.outputs.output_parameters

Set default for `fermi_energy` parameter in wan2skeaf.jl to "none"
  • Loading branch information
npaulish committed Jul 9, 2024
1 parent 027898d commit 46ec160
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 18 deletions.
3 changes: 3 additions & 0 deletions aiida_skeaf/calculations/wan2skeaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def prepare_for_submission(self, folder):
cmdline_params += ["-p", parameters["occupation_prefactor"]]
if "tol_n_electrons" in parameters:
cmdline_params += ["-t", parameters["tol_n_electrons"]]
if "fermi_energy" in parameters:
cmdline_params += ["-f", parameters["fermi_energy"]]

cmdline_params.append(self.inputs.bxsf_filename.value)
#
Expand Down Expand Up @@ -190,6 +192,7 @@ def prepare_for_submission(self, folder):
Optional("smearing_value"): float,
Optional("occupation_prefactor"): int,
Optional("tol_n_electrons"): float,
Optional("fermi_energy"): float,
}


Expand Down
18 changes: 14 additions & 4 deletions aiida_skeaf/parsers/wan2skeaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,24 @@ def parse(self, **kwargs):

self.out("output_parameters", output_node)

band_indexes_in_bxsf = output_node.get_dict().get("band_indexes_in_bxsf")
bands_crossing_fermi = output_node.get_dict().get("bands_crossing_fermi")

# attach RemoteData for extracted bxsf
self.logger.info("Attaching extracted bxsf files")
self.attach_bxsf_files(band_indexes_in_bxsf)
self.attach_bxsf_files(bands_crossing_fermi)

return ExitCode(0)

def attach_bxsf_files( # pylint: disable=inconsistent-return-statements
self, band_indexes_in_bxsf
self, bands_crossing_fermi
):
"""Attach RemoteData for extracted bxsf."""

input_params = self.node.inputs["parameters"].get_dict()
input_band_index = input_params.get("band_index", -1)

if input_band_index == -1:
indexes = band_indexes_in_bxsf
indexes = bands_crossing_fermi
else:
indexes = [input_band_index]

Expand Down Expand Up @@ -192,6 +192,11 @@ def parse_wan2skeaf_out(filecontent: ty.List[str]) -> orm.Dict:
r"Final tolerance for number of electrons:\s*([+-]?(?:[0-9]*[.])?[0-9]+e?[+-]?[0-9]*)"
),
"band_indexes_in_bxsf": re.compile(r"Bands in bxsf:\s*(.+)"),
"custom_fermi_energy": re.compile(
r"Custom Fermi energy will be used to select the bands "
+ r"that are written to separate bxsfs:\s*([+-]?(?:[0-9]*[.])?[0-9]+)"
),
"bands_crossing_fermi": re.compile(r"Bands crossing Fermi energy:\s*(.+)"),
"timestamp_end": re.compile(r"Job done at\s*(.+)"),
}
re_band_minmax = re.compile(
Expand Down Expand Up @@ -233,6 +238,9 @@ def parse_wan2skeaf_out(filecontent: ty.List[str]) -> orm.Dict:
parameters["band_indexes_in_bxsf"] = [
int(_) for _ in parameters["band_indexes_in_bxsf"].split()
]
parameters["bands_crossing_fermi"] = [
int(_) for _ in parameters["bands_crossing_fermi"].split()
]
float_keys = [
"smearing_width",
"tol_n_electrons_initial",
Expand All @@ -250,6 +258,8 @@ def parse_wan2skeaf_out(filecontent: ty.List[str]) -> orm.Dict:
parameters["num_bands"] = int(parameters["num_bands"])
parameters["num_electrons"] = int(parameters["num_electrons"])
parameters["occupation_prefactor"] = int(parameters["occupation_prefactor"])
if "custom_fermi_energy" in parameters:
parameters["custom_fermi_energy"] = float(parameters["custom_fermi_energy"])

# make sure the order is the same as parameters["band_indexes_in_bxsf"]
parameters["band_min"] = [
Expand Down
28 changes: 24 additions & 4 deletions aiida_skeaf/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,25 @@
__all__ = ["validate_inputs", "SkeafWorkChain"]


def validate_inputs(
inputs: AttributeDict, ctx=None # pylint: disable=unused-argument
) -> None:
def validate_inputs( # pylint: disable=inconsistent-return-statements
inputs: AttributeDict, ctx=None
): # pylint: disable=unused-argument
"""Validate the inputs of the entire input namespace."""
# pylint: disable=no-member

if "fermi_energy" in inputs["skeaf"]["parameters"]:
if "fermi_energy" not in inputs["wan2skeaf"]["parameters"]:
inputs["wan2skeaf"]["parameters"]["fermi_energy"] = inputs["skeaf"][
"parameters"
]["fermi_energy"]
else:
if (
inputs["wan2skeaf"]["parameters"]["fermi_energy"]
!= inputs["skeaf"]["parameters"]["fermi_energy"]
):
return SkeafWorkChain.exit_codes.ERROR_INVALID_INPUT_FERMI.message
elif "fermi_energy" in inputs["wan2skeaf"]["parameters"]:
return SkeafWorkChain.exit_codes.ERROR_INVALID_INPUT_FERMI.message


class SkeafWorkChain(ProtocolMixin, WorkChain):
Expand Down Expand Up @@ -92,6 +107,11 @@ def define(cls, spec) -> None:
"ERROR_SUB_PROCESS_FAILED_SKEAF",
message="Unrecoverable error when running skeaf.",
)
spec.exit_code(
500,
"ERROR_INVALID_INPUT_FERMI",
message="Invalid input parameters. Fermi energy is not consistent between skeaf and wan2skeaf.",
)

@classmethod
def get_protocol_filepath(cls) -> pathlib.Path:
Expand All @@ -103,7 +123,7 @@ def get_protocol_filepath(cls) -> pathlib.Path:
return files(protocols) / "skeaf.yaml"

@classmethod
def get_builder_from_protocol( # pylint: disable=too-many-statements
def get_builder_from_protocol( # pylint: disable=too-many-statements, too-many-arguments
cls,
codes: ty.Dict[str, ty.Union[orm.Code, str, int]],
*,
Expand Down
41 changes: 31 additions & 10 deletions utils/wan2skeaf.jl/wan2skeaf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,23 @@ The script
- `-n, --num_electrons`: number of electrons
- `-b, --band_index`: band index, default is -1 (all bands)
- `-o, --out_filename`: output filename prefix
- `-s, --smearing_type`: smearing type, default is `NoneSmearing()` (no smearing)
- `-s, --smearing_type`: smearing type, default is "none" (corresponding to `NoneSmearing()`, no smearing), other options are "fermi-dirac" or "fd", "marzari-vanderbilt" or "cold"
- `-w, --width_smearing`: smearing width, default is 0.0
- `-p, --prefactor`: occupation prefactor, 2 for non SOC, 1 for SOC, default is 2
- `-t, --tol_n_electrons`: tolerance for number of electrons, default is 1e-6
- `-f, --fermi_energy`: custom Fermi energy, default is none
"""
@main function main(bxsf::String; num_electrons::Int, band_index::Int=-1, out_filename::String="skeaf", smearing_type::String="none", width_smearing::Float64=0.0, prefactor::Int=2, tol_n_electrons::Float64=1e-6)
@main function main(
bxsf::String;
num_electrons::Int,
band_index::Int=-1,
out_filename::String="skeaf",
smearing_type::String="none",
width_smearing::Float64=0.0,
prefactor::Int=2,
tol_n_electrons::Float64=1e-6,
fermi_energy::String="none"
)
println("Started on ", Dates.now())
if !isfile(bxsf)
println("ERROR: input file $bxsf does not exist.")
Expand Down Expand Up @@ -82,6 +93,10 @@ The script
println("Smearing width: ", width_smearing)
println("Occupation prefactor: ", prefactor)
println("Initial tolerance for number of electrons (default 1e-6): ", tol_n_electrons)
parsed_fermi_energy = fermi_energy == "none" ? nothing : tryparse(Float64, fermi_energy)
if !isnothing(parsed_fermi_energy)
println("Custom Fermi energy will be used to select the bands that are written to separate bxsfs: ", parsed_fermi_energy)
end

# some times, w/o smearing, the number of electrons cannot be integrated to
# the exact number of electrons, since we only have discrete eigenvalues.
Expand Down Expand Up @@ -155,20 +170,26 @@ The script
band_range = [band_index]
end
println("Bands in bxsf: ", join([string(_) for _ in band_range], " "))
bands_crossing_fermi = zeros(Int,0)
for ib in band_range
# here I am still using the Fermi energy from input bxsf, i.e., QE scf Fermi
outfile = out_filename * "_band_$(ib).bxsf"
band_min = minimum(bxsf.E[ib:ib, :, :, :])
band_max = maximum(bxsf.E[ib:ib, :, :, :])
println("Min and max of band $ib : $band_min $band_max")
#if (bxsf.fermi_energy >= band_min && bxsf.fermi_energy <= band_max)
E_band_Ry = bxsf.E[ib:ib, :, :, :].*(ELECTRONVOLT_SI/RYDBERG_SI)
E_fermi_Ry = bxsf.fermi_energy*(ELECTRONVOLT_SI/RYDBERG_SI)
span_vectors_bohr = bxsf.span_vectors.*BOHR_TO_ANG/2/pi
# what about the origin? It has to be zero (Gamma point) for bxsf so I don't change it here
WannierIO.write_bxsf(outfile, E_fermi_Ry, bxsf.origin, span_vectors_bohr, E_band_Ry)
#end
end

# Check if the Fermi energy (could be custom!) is between the band_min and band_max
# only then write the file
ϵF = isnothing(parsed_fermi_energy) ? εF_bxsf : parsed_fermi_energy
if (ϵF >= band_min && ϵF <= band_max)
push!(bands_crossing_fermi, ib)
E_band_Ry = bxsf.E[ib:ib, :, :, :].*(ELECTRONVOLT_SI/RYDBERG_SI)
E_fermi_Ry = bxsf.fermi_energy*(ELECTRONVOLT_SI/RYDBERG_SI)
span_vectors_bohr = bxsf.span_vectors.*BOHR_TO_ANG/2/pi
# what about the origin? It has to be zero (Gamma point) for bxsf so I don't change it here
WannierIO.write_bxsf(outfile, E_fermi_Ry, bxsf.origin, span_vectors_bohr, E_band_Ry)
end
end
println("Bands crossing Fermi energy: ", join([string(_) for _ in bands_crossing_fermi], " "))
println("Job done at ", Dates.now())
end

0 comments on commit 46ec160

Please sign in to comment.