Skip to content

Commit

Permalink
Add trampoline for AdjointIntegrator
Browse files Browse the repository at this point in the history
  • Loading branch information
njroussel committed Mar 20, 2023
1 parent 85bf3c6 commit c4a8b31
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Being an experimental research framework, Mitsuba 3 does not strictly follow the
`Semantic Versioning <https://semver.org/>`_ convention. That said, we will
strive to document breaking API changes in the release notes below.

Incoming changes
----------------

- Allow extending ``AdjointIntegrator`` in Python

Mitsuba 3.2.1
-------------
Expand Down
13 changes: 13 additions & 0 deletions src/integrators/tests/test_ptracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,16 @@ def test06_ptracer_gradients(variants_all_ad_rgb):
g = dr.grad(params[key])
assert dr.shape(g) == dr.shape(params[key])
assert dr.allclose(g, 0.33647, atol=1e-5) # TODO improve this test (don't use hardcoded value)


def test07_adjoint_integrator_trampoline(variants_all_ad_rgb):

class MyPtracer(mi.AdjointIntegrator):
def __init__(self, props=mi.Properties()):
super().__init__(props)

mi.register_integrator("myptracer", MyPtracer)

mi.load_dict({
'type': 'myptracer'
})
59 changes: 54 additions & 5 deletions src/render/python/integrator_v.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ScopedSignalHandler::~ScopedSignalHandler() {
/// Trampoline for derived types implemented in Python
MI_VARIANT class PySamplingIntegrator : public SamplingIntegrator<Float, Spectrum> {
public:
MI_IMPORT_TYPES(SamplingIntegrator, Scene, Sensor, Sampler, Medium, Emitter, EmitterPtr, BSDF, BSDFPtr)
MI_IMPORT_TYPES(SamplingIntegrator, Scene, Sensor, Sampler, Medium)

PySamplingIntegrator(const Properties &props) : SamplingIntegrator(props) {
if constexpr (!dr::is_jit_v<Float>) {
Expand All @@ -71,11 +71,10 @@ MI_VARIANT class PySamplingIntegrator : public SamplingIntegrator<Float, Spectru
py::gil_scoped_acquire gil;
py::function render_override = py::get_override(this, "render");

if (render_override) {
if (render_override)
return render_override(scene, sensor, seed, spp, develop, evaluate).template cast<TensorXf>();
} else {
else
return SamplingIntegrator::render(scene, sensor, seed, spp, develop, evaluate);
}
}

std::pair<Spectrum, Mask> sample(const Scene *scene,
Expand Down Expand Up @@ -109,6 +108,54 @@ MI_VARIANT class PySamplingIntegrator : public SamplingIntegrator<Float, Spectru
}
};

/// Trampoline for derived types implemented in Python
MI_VARIANT class PyAdjointIntegrator : public AdjointIntegrator<Float, Spectrum> {
public:
MI_IMPORT_TYPES(AdjointIntegrator, Scene, Sensor, Sampler, ImageBlock)

PyAdjointIntegrator(const Properties &props) : AdjointIntegrator(props) {
if constexpr (!dr::is_jit_v<Float>) {
Log(Warn, "AdjointIntegrator Python implementations will have "
"terrible performance in scalar_* modes. It is strongly "
"recommended to switch to a cuda_* or llvm_* mode");
}
}

TensorXf render(Scene *scene,
Sensor *sensor,
uint32_t seed,
uint32_t spp,
bool develop,
bool evaluate) override {
py::gil_scoped_acquire gil;
py::function render_override = py::get_override(this, "render");

if (render_override)
return render_override(scene, sensor, seed, spp, develop, evaluate).template cast<TensorXf>();
else
return AdjointIntegrator::render(scene, sensor, seed, spp, develop, evaluate);
}

void sample(const Scene *scene, const Sensor *sensor, Sampler *sampler,
ImageBlock *block, ScalarFloat sample_scale) const override {
py::gil_scoped_acquire gil;
py::function sample_override = py::get_override(this, "sample");

if (sample_override)
sample_override(scene, sensor, sampler, block, sample_scale);
else
Throw("AdjointIntegrator doesn't overload the method \"sample\"");
}

std::vector<std::string> aov_names() const override {
PYBIND11_OVERRIDE(std::vector<std::string>, AdjointIntegrator, aov_names, );
}

std::string to_string() const override {
PYBIND11_OVERRIDE(std::string, AdjointIntegrator, to_string, );
}
};

/**
* \brief Abstract integrator that should **exclusively** be used to trampoline
* Python AD integrators for primal renderings
Expand Down Expand Up @@ -202,6 +249,7 @@ MI_VARIANT class PyADIntegrator : public CppADIntegrator<Float, Spectrum> {
MI_PY_EXPORT(Integrator) {
MI_PY_IMPORT_TYPES()
using PySamplingIntegrator = PySamplingIntegrator<Float, Spectrum>;
using PyAdjointIntegrator = PyAdjointIntegrator<Float, Spectrum>;
using CppADIntegrator = CppADIntegrator<Float, Spectrum>;
using PyADIntegrator = PyADIntegrator<Float, Spectrum>;

Expand Down Expand Up @@ -257,7 +305,8 @@ MI_PY_EXPORT(Integrator) {
PyADIntegrator>(m, "CppADIntegrator")
.def(py::init<const Properties &>());

MI_PY_CLASS(AdjointIntegrator, Integrator)
MI_PY_TRAMPOLINE_CLASS(PyAdjointIntegrator, AdjointIntegrator, Integrator)
.def(py::init<const Properties &>())
.def_method(AdjointIntegrator, sample, "scene"_a, "sensor"_a,
"sampler"_a, "block"_a, "sample_scale"_a);
}

0 comments on commit c4a8b31

Please sign in to comment.