Skip to content

Commit

Permalink
Fix workflow attribute fallbacks.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Jun 8, 2024
1 parent c475ebb commit 83d9023
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion law/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, *args, **kwargs):

self._workflow_has_reset_branch_map = False

def _get_task_attribute(self, name, fallback=True):
def _get_task_attribute(self, name, fallback=False):
"""
Return an attribute of the actual task named ``<workflow_type>_<name>``. When the attribute
does not exist and *fallback* is *True*, try to return the task attribute simply named
Expand Down
8 changes: 4 additions & 4 deletions law/workflow/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def dump_job_data(self):
logger.debug("job data dumped")

def get_run_context(self):
return self._get_task_attribute("workflow_run_context")()
return self._get_task_attribute("workflow_run_context", fallback=True)()

def run(self):
with self.get_run_context():
Expand Down Expand Up @@ -698,7 +698,7 @@ def _run_impl(self):

# sleep once to give the job interface time to register the jobs
if not self._submitted and not task.no_poll:
post_submit_delay = self._get_task_attribute("post_submit_delay")()
post_submit_delay = self._get_task_attribute("post_submit_delay", fallback=True)()
if post_submit_delay > 0:
logger.debug("sleep for {} second(s) due to post_submit_delay".format(
post_submit_delay))
Expand Down Expand Up @@ -935,7 +935,7 @@ def _submit_batch(self, submit_jobs, **kwargs):
job_files = [f["job"] for f in six.itervalues(all_job_files)]

# prepare objects for dumping intermediate job data
dump_freq = self._get_task_attribute("dump_intermediate_job_data")()
dump_freq = self._get_task_attribute("dump_intermediate_job_data", fallback=True)()
if dump_freq and not is_number(dump_freq):
dump_freq = 50

Expand Down Expand Up @@ -1306,7 +1306,7 @@ def poll(self):
raise Exception(err.format(self.poll_data.n_finished_min, n_jobs, n_failed))

# invoke the poll callback
poll_callback_res = self._get_task_attribute("poll_callback")(self.poll_data)
poll_callback_res = self._get_task_attribute("poll_callback", fallback=True)(self.poll_data) # noqa
if poll_callback_res is False:
logger.debug(
"job polling loop gracefully stopped due to False returned by poll_callback",
Expand Down

0 comments on commit 83d9023

Please sign in to comment.