Skip to content

Commit

Permalink
Add mi.variant_context()
Browse files Browse the repository at this point in the history
  • Loading branch information
leroyvn authored and njroussel committed May 9, 2023
1 parent 9273af4 commit 96b219d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/python/python/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .util import traverse, SceneParameters, render, cornell_box
from .util import traverse, SceneParameters, render, cornell_box, variant_context
from . import chi2
from . import xml
from . import ad
27 changes: 27 additions & 0 deletions src/python/python/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,30 @@ def test05_render_fwd_assert(variants_all_ad_rgb):
with pytest.raises(Exception) as e:
img = mi.render(scene)
dr.forward_to(img)


def test06_variant_context():
# Select the first variant which is not 'scalar_rgb'
for variant in mi.variants():
if variant != "scalar_rgb":
override_variant = variant
break
else:
pytest.skip("Only the 'scalar_rgb' variant was compiled.")

# Now, the test
mi.set_variant("scalar_rgb")

# The active variant is temporarily overridden
with mi.variant_context(override_variant):
assert mi.variant() == override_variant
assert mi.variant() == "scalar_rgb"

# The initial variant is restored if an exception is raised
try:
with mi.variant_context(override_variant):
assert mi.variant() == override_variant
raise RuntimeError
except RuntimeError:
pass
assert mi.variant() == "scalar_rgb"
18 changes: 18 additions & 0 deletions src/python/python/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as __annotations__ # Delayed parsing of type annotations

import contextlib
from collections.abc import Mapping

import drjit as dr
Expand Down Expand Up @@ -696,3 +697,20 @@ def cornell_box():
}
},
}


@contextlib.contextmanager
def variant_context(*args) -> None:
'''
Temporarily override the active variant. Arguments are interpreted as
they are in :func:`mitsuba.set_variant`.
'''

old_variant = mi.variant()
try:
mi.set_variant(*args)
yield
except Exception:
raise
finally:
mi.set_variant(old_variant)

0 comments on commit 96b219d

Please sign in to comment.