Skip to content

Commit

Permalink
Merge pull request #23642 from mattjj:tweak-error-logic
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674973797
  • Loading branch information
Google-ML-Automation committed Sep 16, 2024
2 parents 839ce9a + 02bb3d1 commit a8b996a
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,11 +2062,9 @@ def process_primitive(self, primitive, tracers, params):
def default_process_primitive(self, primitive, tracers, params):
avals = [t.aval for t in tracers]
out_avals, effects = primitive.abstract_eval(*avals, **params)
# == serve as a "not xor" here.
if not (isinstance(out_avals, (tuple,list)) == primitive.multiple_results):
raise ValueError(f"{primitive}.abstract_eval() method should return"
f" a tuple or a list if {primitive}.multiple_results"
" is true. Otherwise it shouldn't.")
if isinstance(out_avals, (tuple, list)) != primitive.multiple_results:
raise ValueError(f"{primitive}.abstract_eval() method should return "
f"a tuple or a list iff {primitive}.multiple_results.")
out_avals = [out_avals] if not primitive.multiple_results else out_avals
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
Expand Down

0 comments on commit a8b996a

Please sign in to comment.