Skip to content

Commit

Permalink
Ensure printv gets one string
Browse files Browse the repository at this point in the history
  • Loading branch information
mmarkakis committed Aug 1, 2024
1 parent e6cd9f5 commit 7c4b8aa
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "eccs"
version = "0.1.3"
version = "0.1.4"
authors = [
{ name="Markos Markakis", email="markakis@mit.edu"},
{ name="Sylvia Ziyu Zhang", email="sylziyuz@mit.edu"},
Expand Down
14 changes: 7 additions & 7 deletions src/eccs/eccs.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def _edit_and_get_ate(self, edits: list[EdgeEdit]) -> Optional[float]:

# Edit graph
for src, dst, edit_type in edits:
Printer.printv("Applying edit: ", src, dst, edit_type)
Printer.printv(f"Applying edit: {src}, {dst}, {edit_type}")
if edit_type == EdgeEditType.ADD:
if not graph.has_edge(src, dst):
graph.add_edge(src, dst)
Expand All @@ -492,12 +492,12 @@ def _edit_and_get_ate(self, edits: list[EdgeEdit]) -> Optional[float]:

# Compute the ATE if the graph is acceptable
if self._is_acceptable(graph):
Printer.printv("Graph is acceptable after edits: ", edits)
Printer.printv(f"Graph is acceptable after edits: {edits}")
ate = self.get_ate(graph)
Printer.printv("Got back ATE: ", ate)
Printer.printv(f"Got back ATE: {ate}")
return ate

Printer.printv("Graph is not acceptable after edits: ", edits)
Printer.printv(f"Graph is not acceptable after edits: {edits}")
return None

def _edit_and_draw(self, edits: list[EdgeEdit]) -> Optional[str]:
Expand Down Expand Up @@ -806,7 +806,7 @@ def maybe_update_ranking(v, edits, ate):
base_adj_set = ECCS._find_adjustment_set(
self._graph, self.treatment, self.outcome
)
Printer.printv("Found base adjustment set: ", base_adj_set)
Printer.printv(f"Found base adjustment set: {base_adj_set}")
vars_not_in_adj_set = [
v
for v in self.vars
Expand All @@ -826,15 +826,15 @@ def maybe_update_ranking(v, edits, ate):
for v in vars_not_in_adj_set:
Printer.printv(f"Trying to add {v} to the adjustment set")
edits = mapper.map_addition(v, use_optimized)
Printer.printv("Got back edits for addition: ", edits)
Printer.printv(f"Got back edits for addition: {edits}")
ate = self._edit_and_get_ate(edits)
maybe_update_ranking(v, edits, ate)

# Try removing each of the removable
for v in base_adj_set:
Printer.printv(f"Trying to remove {v} from the adjustment set")
edits = mapper.map_removal(v, use_optimized)
Printer.printv("Got back edit lists for removal: ", edits)
Printer.printv(f"Got back edit lists for removal: {edits}")
ate = self._edit_and_get_ate(edits)
maybe_update_ranking(v, edits, ate)

Expand Down
8 changes: 2 additions & 6 deletions src/eccs/heuristic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,7 @@ def astar(self, k: int = 100):
neighbors = self._get_neighbors(current_node_id, frontier, n_lookahead)
if self._cur_next_id > self._computational_budget:
Printer.printv(
"Out of computational budget: ",
self._cur_next_id,
self._computational_budget,
f"Out of computational budget: {self._cur_next_id}, {self._computational_budget}"
)
break
for neighbor_id, edge_type in neighbors:
Expand All @@ -316,9 +314,7 @@ def astar(self, k: int = 100):

if self._cur_next_id > self._computational_budget:
Printer.printv(
"Out of computational budget: ",
self._cur_next_id,
self._computational_budget,
f"Out of computational budget: {self._cur_next_id}, {self._computational_budget}"
)
break
self._visited.add(current_node_id)
Expand Down

0 comments on commit 7c4b8aa

Please sign in to comment.