Skip to content

Commit

Permalink
Parallelize load_dict routine
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Mar 21, 2023
1 parent c4a8b31 commit bb672ed
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 93 deletions.
23 changes: 23 additions & 0 deletions include/mitsuba/core/xml.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
NAMESPACE_BEGIN(mitsuba)
NAMESPACE_BEGIN(xml)

struct ScopedSetJITScope {
ScopedSetJITScope(uint32_t backend, uint32_t scope) : backend(backend) {
#if defined(MI_ENABLE_LLVM) || defined(MI_ENABLE_CUDA)
if (backend) {
backup = jit_scope((JitBackend) backend);
jit_set_scope((JitBackend) backend, scope);
}
#endif
}

~ScopedSetJITScope() {
#if defined(MI_ENABLE_LLVM) || defined(MI_ENABLE_CUDA)
if (backend)
jit_set_scope((JitBackend) backend, backup);
#endif
}

uint32_t backend, backup;
};

/// Used to pass key=value pairs to the parser
using ParameterList = std::vector<std::tuple<std::string, std::string, bool>>;

Expand All @@ -29,6 +49,9 @@ using ParameterList = std::vector<std::tuple<std::string, std::string, bool>>;
* \param update_scene
* When Mitsuba updates scene to a newer version, should the
* updated XML file be written back to disk?
*
* \param parallel
* Whether the loading should be executed on multiple threads in parallel
*/
extern MI_EXPORT_LIB std::vector<ref<Object>> load_file(
const fs::path &path,
Expand Down
5 changes: 4 additions & 1 deletion include/mitsuba/python/docstr.h
Original file line number Diff line number Diff line change
Expand Up @@ -11743,7 +11743,10 @@ Parameter ``variant``:
Parameter ``update_scene``:
When Mitsuba updates scene to a newer version, should the updated
XML file be written back to disk?)doc";
XML file be written back to disk?
Parameter ``parallel``:
Whether the loading should be executed on multiple threads in parallel)doc";

static const char *__doc_mitsuba_xml_load_string = R"doc(Load a Mitsuba scene from an XML string)doc";

Expand Down
240 changes: 169 additions & 71 deletions src/core/python/xml_v.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,38 @@
#include <mitsuba/core/spectrum.h>
#include <mitsuba/core/transform.h>
#include <mitsuba/python/python.h>
#include <nanothread/nanothread.h>
#include <map>

using Caster = py::object(*)(mitsuba::Object *);
extern Caster cast_object;

struct DictInstance {
Properties props;
ref<Object> object = nullptr;
uint32_t scope;
std::vector<std::pair<std::string, std::string>> dependencies;
};

struct DictParseContext {
ThreadEnvironment env;
std::map<std::string, DictInstance> instances;
std::map<std::string, std::string> aliases;
bool parallel;
};

// Forward declaration
template <typename Float, typename Spectrum>
std::vector<ref<Object>> load_dict(
const std::string &dict_key,
const py::dict &dict,
std::map<std::string, ref<Object>> &instances
void parse_dictionary(
DictParseContext &ctx,
const std::string path,
const py::dict &dict
);
template <typename Float, typename Spectrum>
Task * instantiate_node(
DictParseContext &ctx,
const std::string path,
std::unordered_map<std::string, Task *> &task_map
);

/// Shorthand notation for accessing the MI_VARIANT string
Expand Down Expand Up @@ -87,21 +108,24 @@ MI_PY_EXPORT(xml) {

m.def(
"load_dict",
[](const py::dict dict) {
std::map<std::string, ref<Object>> instances;
std::vector<ref<Object>> objects =
load_dict<Float, Spectrum>("", dict, instances);

py::object out = single_object_or_list(objects);

return out;
[](const py::dict dict, bool parallel) {
DictParseContext ctx;
ctx.parallel = parallel;
parse_dictionary<Float, Spectrum>(ctx, "__root__", dict);
std::unordered_map<std::string, Task*> task_map;
instantiate_node<Float, Spectrum>(ctx, "__root__", task_map);
auto objects = mitsuba::xml::detail::expand_node(ctx.instances["__root__"].object);
return single_object_or_list(objects);
},
"dict"_a,
"dict"_a, "parallel"_a=true,
R"doc(Load a Mitsuba scene or object from an Python dictionary
Parameter ``dict``:
Python dictionary containing the object description
Parameter ``parallel``:
Whether the loading should be executed on multiple threads in parallel
)doc");

m.def(
Expand Down Expand Up @@ -210,19 +234,22 @@ ref<Object> create_texture_from(const py::dict &dict, bool within_emitter) {
}

template <typename Float, typename Spectrum>
std::vector<ref<Object>> load_dict(const std::string &dict_key,
const py::dict &dict,
std::map<std::string,
ref<Object>> &instances) {
void parse_dictionary(DictParseContext &ctx,
const std::string path,
const py::dict &dict) {
MI_IMPORT_CORE_TYPES()
using ScalarArray3f = dr::Array<ScalarFloat, 3>;

std::string type = get_type(dict);

if (type == "spectrum" || type == "rgb")
return { create_texture_from<Float, Spectrum>(dict, false) };
auto &inst = ctx.instances[path];

std::string type = get_type(dict);
bool is_scene = (type == "scene");
bool is_root = string::starts_with(path, "__root__");

if (type == "spectrum" || type == "rgb") {
inst.object = create_texture_from<Float, Spectrum>(dict, false);
return;
}

const Class *class_;
if (is_scene)
Expand All @@ -231,7 +258,11 @@ std::vector<ref<Object>> load_dict(const std::string &dict_key,
class_ = PluginManager::instance()->get_plugin_class(type, GET_VARIANT())->parent();

bool within_emitter = (!is_scene && class_->alias() == "emitter");
Properties props(type);

Properties &props = inst.props;
props.set_plugin_name(type);

std::string id;

for (auto& [k, value] : dict) {
std::string key = k.template cast<std::string>();
Expand All @@ -240,7 +271,7 @@ std::vector<ref<Object>> load_dict(const std::string &dict_key,
continue;

if (key == "id") {
props.set_id(value.template cast<std::string>());
id = value.template cast<std::string>();
continue;
}

Expand All @@ -252,7 +283,7 @@ std::vector<ref<Object>> load_dict(const std::string &dict_key,
SET_PROPS(ScalarArray3f, ScalarArray3f, set_array3f);
SET_PROPS(ScalarTransform4f, ScalarTransform4f, set_transform);

// Load nested dictionary
// Parse nested dictionary
if (py::isinstance<py::dict>(value)) {
py::dict dict2 = value.template cast<py::dict>();
std::string type2 = get_type(dict2);
Expand All @@ -268,51 +299,27 @@ std::vector<ref<Object>> load_dict(const std::string &dict_key,
if (is_scene)
Throw("Reference found at the scene level: %s", key);

for (auto& [k2, value2] : value.template cast<py::dict>()) {
std::string key2 = k2.template cast<std::string>();
for (auto& kv2 : value.template cast<py::dict>()) {
std::string key2 = kv2.first.template cast<std::string>();
if (key2 == "id") {
std::string id = value2.template cast<std::string>();
if (instances.count(id) == 1)
expand_and_set_object(props, key, instances[id]);
std::string id2 = kv2.second.template cast<std::string>();
std::string path2;
if (ctx.aliases.count(id2) == 1)
path2 = ctx.aliases[id2];
else
Throw("Referenced id \"%s\" not found: %s", id, key);
} else if (key2 != "type") {
path2 = id2;
if (ctx.instances.count(path2) != 1)
Throw("Referenced id \"%s\" not found: %s", path2, path);
inst.dependencies.push_back({key, path2});
} else if (key2 != "type") {
Throw("Unexpected key in ref dictionary: %s", key2);
}
}
continue;
} else {
std::string path2 = is_root ? key : path + "." + key;
inst.dependencies.push_back({key, path2});
parse_dictionary<Float, Spectrum>(ctx, path2, dict2);
}

// Load the dictionary recursively
std::vector<ref<Object>> objects = load_dict<Float, Spectrum>(key, dict2, instances);
size_t n_objects = objects.size();
int ctr = 0;

for (auto &obj : objects) {
if (n_objects > 1) {
props.set_object(key + "_" + std::to_string(ctr++), obj);
} else {
props.set_object(key, obj);
}

// Add instanced object to the instance map for later references
if (is_scene) {
// An object can be referenced using its key
if (instances.count(key) != 0)
Throw("%s has duplicate id: %s", key, key);
instances[key] = obj;

// An object can also be referenced using its "id" if it has
// one
std::string id = obj->id();
if (!id.empty() && id != key) {
if (instances.count(id) != 0)
Throw("%s has duplicate id: %s", key, id);
instances[id] = obj;
}
}
}

continue;
}

Expand All @@ -334,17 +341,108 @@ std::vector<ref<Object>> load_dict(const std::string &dict_key,
Throw("Unkown value type: %s", value.get_type());
}

// Use the dict key as id (if available) if no id was already set
if (props.id().empty() && !dict_key.empty())
props.set_id(dict_key);
// Set object id based on path in dictionary if no id is provided
props.set_id(id.empty() ? string::tokenize(path, ".").back() : id);

if constexpr (dr::is_jit_v<Float>) {
if (ctx.parallel) {
jit_new_scope(dr::backend_v<Float>);
inst.scope = jit_scope(dr::backend_v<Float>);
}
}

if (!id.empty()) {
if (ctx.aliases.count(id) != 0)
Throw("%s has duplicate id: %s", path, id);
ctx.aliases[id] = path;
}
}

// Construct the object with the parsed Properties
auto obj = PluginManager::instance()->create_object(props, class_);
template <typename Float, typename Spectrum>
Task *instantiate_node(DictParseContext &ctx,
std::string path,
std::unordered_map<std::string, Task *> &task_map) {
if (task_map.find(path) != task_map.end())
return task_map.find(path)->second;

auto &inst = ctx.instances[path];
uint32_t scope = inst.scope;
uint32_t backend = (uint32_t) dr::backend_v<Float>;
bool is_root = path == "__root__";

// Early exit if the object was already instantiated
if (inst.object)
return nullptr;

std::vector<Task *> deps;
for (auto &[key2, path2] : inst.dependencies) {
if (task_map.find(path2) == task_map.end()) {
Task *task = instantiate_node<Float, Spectrum>(ctx, path2, task_map);
task_map.insert({path2, task});
}
deps.push_back(task_map.find(path2)->second);
}

if (!props.unqueried().empty())
Throw("Unreferenced property \"%s\" in plugin of type \"%s\"!", props.unqueried()[0], type);
auto instantiate = [&ctx, path, scope, backend]() {
ScopedSetThreadEnvironment set_env(ctx.env);
mitsuba::xml::ScopedSetJITScope set_scope(ctx.parallel ? backend : 0u, scope);

auto &inst = ctx.instances[path];
Properties props = inst.props;
std::string type = props.plugin_name();

const Class *class_;
if (type == "scene")
class_ = Class::for_name("Scene", GET_VARIANT());
else
class_ = PluginManager::instance()->get_plugin_class(type, GET_VARIANT())->parent();

for (auto &[key2, path2] : inst.dependencies) {
if (ctx.instances.count(path2) == 1) {
auto obj2 = ctx.instances[path2].object;
if (obj2)
expand_and_set_object(props, key2, obj2);
else
Throw("Dependence hasn't been instantiated yet: %s, %s -> %s", path, path2, key2);
} else {
Throw("Dependence path \"%s\" not found: %s", path2, path);
}
}

return mitsuba::xml::detail::expand_node(obj);
// Construct the object with the parsed Properties
inst.object = PluginManager::instance()->create_object(props, class_);

if (!props.unqueried().empty())
Throw("Unreferenced property \"%s\" in plugin of type \"%s\"!", props.unqueried()[0], type);
};

// Top node always instantiated on the main thread
if (is_root) {
std::exception_ptr eptr;
for (auto& task : deps) {
try {
py::gil_scoped_release gil_release{};
task_wait(task);
} catch (...) {
if (!eptr)
eptr = std::current_exception();
}
}
for (auto& kv : task_map)
task_release(kv.second);
if (eptr)
std::rethrow_exception(eptr);
instantiate();
return nullptr;
} else {
if (ctx.parallel) {
// Instantiate object asynchronously
return dr::do_async(instantiate, deps.data(), deps.size());
} else {
instantiate();
return nullptr;
}
}
}

#undef SET_PROPS
21 changes: 0 additions & 21 deletions src/core/xml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,27 +985,6 @@ static std::pair<std::string, std::string> parse_xml(XMLSource &src, XMLParseCon
return std::make_pair("", "");
}

struct ScopedSetJITScope {
ScopedSetJITScope(uint32_t backend, uint32_t scope) : backend(backend) {
#if defined(MI_ENABLE_LLVM) || defined(MI_ENABLE_CUDA)
if (backend) {
backup = jit_scope((JitBackend) backend);
jit_set_scope((JitBackend) backend, scope);
}
#endif
}

~ScopedSetJITScope() {
#if defined(MI_ENABLE_LLVM) || defined(MI_ENABLE_CUDA)
if (backend)
jit_set_scope((JitBackend) backend, backup);
#endif
}

uint32_t backend, backup;
};


static std::string init_xml_parse_context_from_file(XMLParseContext &ctx,
const fs::path &filename_,
ParameterList param,
Expand Down

0 comments on commit bb672ed

Please sign in to comment.