From 83d902384a32db654dd0291ffbf1b7832eb137cd Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 8 Jun 2024 13:43:32 +0200 Subject: [PATCH] Fix workflow attribute fallbacks. --- law/workflow/base.py | 2 +- law/workflow/remote.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/law/workflow/base.py b/law/workflow/base.py index d4a74a16..fe148d41 100644 --- a/law/workflow/base.py +++ b/law/workflow/base.py @@ -76,7 +76,7 @@ def __init__(self, *args, **kwargs): self._workflow_has_reset_branch_map = False - def _get_task_attribute(self, name, fallback=True): + def _get_task_attribute(self, name, fallback=False): """ Return an attribute of the actual task named ``_``. When the attribute does not exist and *fallback* is *True*, try to return the task attribute simply named diff --git a/law/workflow/remote.py b/law/workflow/remote.py index b15d0e9d..d45c1d9b 100644 --- a/law/workflow/remote.py +++ b/law/workflow/remote.py @@ -617,7 +617,7 @@ def dump_job_data(self): logger.debug("job data dumped") def get_run_context(self): - return self._get_task_attribute("workflow_run_context")() + return self._get_task_attribute("workflow_run_context", fallback=True)() def run(self): with self.get_run_context(): @@ -698,7 +698,7 @@ def _run_impl(self): # sleep once to give the job interface time to register the jobs if not self._submitted and not task.no_poll: - post_submit_delay = self._get_task_attribute("post_submit_delay")() + post_submit_delay = self._get_task_attribute("post_submit_delay", fallback=True)() if post_submit_delay > 0: logger.debug("sleep for {} second(s) due to post_submit_delay".format( post_submit_delay)) @@ -935,7 +935,7 @@ def _submit_batch(self, submit_jobs, **kwargs): job_files = [f["job"] for f in six.itervalues(all_job_files)] # prepare objects for dumping intermediate job data - dump_freq = self._get_task_attribute("dump_intermediate_job_data")() + dump_freq = self._get_task_attribute("dump_intermediate_job_data", fallback=True)() if dump_freq and not is_number(dump_freq): dump_freq = 50 @@ -1306,7 +1306,7 @@ def poll(self): raise Exception(err.format(self.poll_data.n_finished_min, n_jobs, n_failed)) # invoke the poll callback - poll_callback_res = self._get_task_attribute("poll_callback")(self.poll_data) + poll_callback_res = self._get_task_attribute("poll_callback", fallback=True)(self.poll_data) # noqa if poll_callback_res is False: logger.debug( "job polling loop gracefully stopped due to False returned by poll_callback",