diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5406aad172c4..2d27bf064fce 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2062,11 +2062,9 @@ def process_primitive(self, primitive, tracers, params): def default_process_primitive(self, primitive, tracers, params): avals = [t.aval for t in tracers] out_avals, effects = primitive.abstract_eval(*avals, **params) - # == serve as a "not xor" here. - if not (isinstance(out_avals, (tuple,list)) == primitive.multiple_results): - raise ValueError(f"{primitive}.abstract_eval() method should return" - f" a tuple or a list if {primitive}.multiple_results" - " is true. Otherwise it shouldn't.") + if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: + raise ValueError(f"{primitive}.abstract_eval() method should return " + f"a tuple or a list iff {primitive}.multiple_results.") out_avals = [out_avals] if not primitive.multiple_results else out_avals source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]