From c4a8b31ee764a0e6d56d9075708c3c76062854be Mon Sep 17 00:00:00 2001 From: Nicolas Roussel Date: Mon, 20 Mar 2023 18:08:09 +0100 Subject: [PATCH] Add trampoline for AdjointIntegrator --- docs/release_notes.rst | 4 ++ src/integrators/tests/test_ptracer.py | 13 ++++++ src/render/python/integrator_v.cpp | 59 ++++++++++++++++++++++++--- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 68b3721a7..1ea3e8892 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -5,6 +5,10 @@ Being an experimental research framework, Mitsuba 3 does not strictly follow the `Semantic Versioning `_ convention. That said, we will strive to document breaking API changes in the release notes below. +Incoming changes +---------------- + +- Allow extending ``AdjointIntegrator`` in Python Mitsuba 3.2.1 ------------- diff --git a/src/integrators/tests/test_ptracer.py b/src/integrators/tests/test_ptracer.py index 5a9b1a821..1b97105c2 100644 --- a/src/integrators/tests/test_ptracer.py +++ b/src/integrators/tests/test_ptracer.py @@ -223,3 +223,16 @@ def test06_ptracer_gradients(variants_all_ad_rgb): g = dr.grad(params[key]) assert dr.shape(g) == dr.shape(params[key]) assert dr.allclose(g, 0.33647, atol=1e-5) # TODO improve this test (don't use hardcoded value) + + +def test07_adjoint_integrator_trampoline(variants_all_ad_rgb): + + class MyPtracer(mi.AdjointIntegrator): + def __init__(self, props=mi.Properties()): + super().__init__(props) + + mi.register_integrator("myptracer", MyPtracer) + + mi.load_dict({ + 'type': 'myptracer' + }) diff --git a/src/render/python/integrator_v.cpp b/src/render/python/integrator_v.cpp index a6257117c..c7a7c9aee 100644 --- a/src/render/python/integrator_v.cpp +++ b/src/render/python/integrator_v.cpp @@ -52,7 +52,7 @@ ScopedSignalHandler::~ScopedSignalHandler() { /// Trampoline for derived types implemented in Python MI_VARIANT class PySamplingIntegrator : public SamplingIntegrator { public: - MI_IMPORT_TYPES(SamplingIntegrator, Scene, Sensor, Sampler, Medium, Emitter, EmitterPtr, BSDF, BSDFPtr) + MI_IMPORT_TYPES(SamplingIntegrator, Scene, Sensor, Sampler, Medium) PySamplingIntegrator(const Properties &props) : SamplingIntegrator(props) { if constexpr (!dr::is_jit_v) { @@ -71,11 +71,10 @@ MI_VARIANT class PySamplingIntegrator : public SamplingIntegrator(); - } else { + else return SamplingIntegrator::render(scene, sensor, seed, spp, develop, evaluate); - } } std::pair sample(const Scene *scene, @@ -109,6 +108,54 @@ MI_VARIANT class PySamplingIntegrator : public SamplingIntegrator { +public: + MI_IMPORT_TYPES(AdjointIntegrator, Scene, Sensor, Sampler, ImageBlock) + + PyAdjointIntegrator(const Properties &props) : AdjointIntegrator(props) { + if constexpr (!dr::is_jit_v) { + Log(Warn, "AdjointIntegrator Python implementations will have " + "terrible performance in scalar_* modes. It is strongly " + "recommended to switch to a cuda_* or llvm_* mode"); + } + } + + TensorXf render(Scene *scene, + Sensor *sensor, + uint32_t seed, + uint32_t spp, + bool develop, + bool evaluate) override { + py::gil_scoped_acquire gil; + py::function render_override = py::get_override(this, "render"); + + if (render_override) + return render_override(scene, sensor, seed, spp, develop, evaluate).template cast(); + else + return AdjointIntegrator::render(scene, sensor, seed, spp, develop, evaluate); + } + + void sample(const Scene *scene, const Sensor *sensor, Sampler *sampler, + ImageBlock *block, ScalarFloat sample_scale) const override { + py::gil_scoped_acquire gil; + py::function sample_override = py::get_override(this, "sample"); + + if (sample_override) + sample_override(scene, sensor, sampler, block, sample_scale); + else + Throw("AdjointIntegrator doesn't overload the method \"sample\""); + } + + std::vector aov_names() const override { + PYBIND11_OVERRIDE(std::vector, AdjointIntegrator, aov_names, ); + } + + std::string to_string() const override { + PYBIND11_OVERRIDE(std::string, AdjointIntegrator, to_string, ); + } +}; + /** * \brief Abstract integrator that should **exclusively** be used to trampoline * Python AD integrators for primal renderings @@ -202,6 +249,7 @@ MI_VARIANT class PyADIntegrator : public CppADIntegrator { MI_PY_EXPORT(Integrator) { MI_PY_IMPORT_TYPES() using PySamplingIntegrator = PySamplingIntegrator; + using PyAdjointIntegrator = PyAdjointIntegrator; using CppADIntegrator = CppADIntegrator; using PyADIntegrator = PyADIntegrator; @@ -257,7 +305,8 @@ MI_PY_EXPORT(Integrator) { PyADIntegrator>(m, "CppADIntegrator") .def(py::init()); - MI_PY_CLASS(AdjointIntegrator, Integrator) + MI_PY_TRAMPOLINE_CLASS(PyAdjointIntegrator, AdjointIntegrator, Integrator) + .def(py::init()) .def_method(AdjointIntegrator, sample, "scene"_a, "sensor"_a, "sampler"_a, "block"_a, "sample_scale"_a); }