diff --git a/prototype/sky/optimizer.py b/prototype/sky/optimizer.py index f4aa5bbfeb6..e17b8c6dd85 100644 --- a/prototype/sky/optimizer.py +++ b/prototype/sky/optimizer.py @@ -231,7 +231,8 @@ def _optimize_cost( logger.debug(f'resources: {resources}') if minimize_cost: - estimated_cost = resources.get_cost(estimated_runtime) + cost_per_node = resources.get_cost(estimated_runtime) + estimated_cost = cost_per_node * node.num_nodes else: # Minimize run time; overload the term 'cost'. estimated_cost = estimated_runtime @@ -244,6 +245,7 @@ def _optimize_cost( ' estimated_cost (not incl. egress): ${:.1f}'. format(estimated_cost)) + # FIXME: Account for egress costs for multi-node clusters sum_parent_cost_and_egress = 0 for parent in parents: min_pred_cost_plus_egress = np.inf @@ -320,8 +322,9 @@ def _walk(node, best_hardware, best_cost): overall_best / 3600)) # Do not print Source or Sink. message_data = [ - t for t in message_data - if t[0].name not in (_DUMMY_SOURCE_NAME, _DUMMY_SINK_NAME) + (t, f'{t.num_nodes}x {repr(r)}' if t.num_nodes > 1 else repr(r)) + for (t, r) in message_data + if t.name not in (_DUMMY_SOURCE_NAME, _DUMMY_SINK_NAME) ] message = tabulate.tabulate(reversed(message_data), headers=['TASK', 'BEST_RESOURCE'], diff --git a/prototype/sky/task.py b/prototype/sky/task.py index 88aa82828aa..c660e530262 100644 --- a/prototype/sky/task.py +++ b/prototype/sky/task.py @@ -486,6 +486,8 @@ def __repr__(self): s += f'\n inputs: {self.inputs}' if self.outputs is not None: s += f'\n outputs: {self.outputs}' + if self.num_nodes > 1: + s += f'\n nodes: {self.num_nodes}' if len(self.resources) > 1 or not list(self.resources)[0].is_empty(): s += f'\n resources: {self.resources}' else: