Skip to content

Commit

Permalink
Ensembles: don't expect OptimizerResult.id to be convertible to int
Browse files Browse the repository at this point in the history
Fixes `Ensemble.from_optimization_{history,endpoints}`, which incorrectly assumed that `OptimizerResult.id` is always convertible to `int`.

Closes ICB-DCM#1349.
  • Loading branch information
dweindl committed Apr 2, 2024
1 parent ace9b8b commit e6e45d3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions pypesto/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -480,7 +480,7 @@ def __init__(
self,
x_vectors: np.ndarray,
x_names: Sequence[str] = None,
vector_tags: Sequence[tuple[int, int]] = None,
vector_tags: Sequence[Any] = None,
ensemble_type: EnsembleType = None,
predictions: Sequence[EnsemblePrediction] = None,
lower_bound: np.ndarray = None,
Expand Down Expand Up @@ -522,7 +522,7 @@ def __init__(
self.x_vectors = x_vectors
self.n_x = x_vectors.shape[0]
self.n_vectors = x_vectors.shape[1]
self.vector_tags = vector_tags
self.vector_tags = list(vector_tags) if vector_tags is not None else []
self.summary = None

# store bounds
Expand Down Expand Up @@ -669,7 +669,7 @@ def from_optimization_endpoints(
x_vectors.append(start["x"][result.problem.x_free_indices])

# the vector tag will be a -1 to indicate it is the last step
vector_tags.append((int(start["id"]), -1))
vector_tags.append((start["id"], -1))
else:
break

Expand Down Expand Up @@ -800,7 +800,7 @@ def from_optimization_history(
x_vectors.extend([x_trace[start][ind] for ind in indices])
vector_tags.extend(
[
(int(result.optimize_result.list[start]["id"]), ind)
(result.optimize_result.list[start]["id"], ind)
for ind in indices
]
)
Expand Down
4 changes: 2 additions & 2 deletions test/base/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def test_ensemble_from_optimization():

# compare vector_tags with the expected values:
ep_tags = [
(int(result.optimize_result.list[i]["id"]), -1) for i in [0, 1, 2, 3]
(result.optimize_result.list[i]["id"], -1) for i in [0, 1, 2, 3]
]

hist_tags = [
(
int(result.optimize_result.list[i]["id"]),
result.optimize_result.list[i]["id"],
len(result.optimize_result.list[i]["history"]._trace["fval"])
- 1
- j,
Expand Down

0 comments on commit e6e45d3

Please sign in to comment.