Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change target string to Target object in the TE compiler and interpreter #8835

Merged
merged 14 commits into from
Aug 31, 2021
55 changes: 55 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/target/target_kind.h>

#include <string>
#include <unordered_map>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably remove this as it's not used in this file.

Copy link
Member

@jroesch jroesch Aug 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will send a follow-up PR that does this, just for the sake of forward progress. Thanks!

#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -203,5 +204,59 @@ void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host);
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host);

// TODO(@electriclilies): Move to somewhere in backend and add note about appropriate use
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, what about moving these methods temporarily to src/relay/backend/utils.h instead? Given that these are only used in the relay backend right now, I think it would help discourage future developers from using them :-)

Copy link
Contributor Author

@electriclilies electriclilies Aug 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved them!


/*!
 * \brief Hash functor for Target that hashes the Target's string representation.
 *
 * Intended for use as the hasher of standard containers keyed by Target.
 * This will be removed when maps from Targets to IRModules are removed from the codebase.
 */
struct TargetStrHash {
  /*!
   * \brief Calculate the hash code of a Target based on the string value of the Target.
   * \param target The Target to hash
   * \return String hash of the target
   */
  size_t operator()(const Target& target) const {
    // Fetch the string form once so the pointer and the length passed to
    // HashBytes are guaranteed to describe the same string object.
    const auto& target_str = target->str();
    return String::HashBytes(target_str.c_str(), target_str.size());
  }
};

/*! \brief Target equality function based on the string value of Target
This will be removed when maps from Targets to IRModules are removed from the
codebase.*/
struct TargetStrEqual {
/*!
* \brief Check if the two Targets are equal
* \param target One Target
* \param other_target The other Target
* \return String equality of the targets
*/
const bool operator()(const Target& target, const Target& other_target) const {
TargetStrHash target_hash = TargetStrHash();
return target_hash(target) == target_hash(other_target);
}
};

/*!
 * \brief Convert a Map<Target, IRModule> to
 * std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>.
 *
 * Target equality is currently based on pointer equality, which is a problem since
 * we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to a
 * version that is keyed based on the string value of the Target instead.
 *
 * \note Once we remove Map<Target, IRModule>, this function will be removed.
 * \param input_map The map to convert
 * \return The converted map
 */
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map);

/*!
 * \brief Convert a std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> to
 * Map<Target, IRModule>.
 *
 * This function is a helper that undoes TargetModuleMapToTargetStrModuleMap.
 *
 * \note Once we remove Map<Target, IRModule>, this function will be removed.
 * \param input_map The map to convert
 * \return The converted map
 */
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map);

} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
9 changes: 4 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor {
ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Update(mod_run);
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_]->Update(mod_run);
} else {
ret.lowered_funcs.Set(target_host_str, mod_run);
ret.lowered_funcs.Set(target_host_, mod_run);
}

std::vector<String> input_var_names(input_vars_.size());
Expand Down Expand Up @@ -778,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
return (*it).second.first;
}

Map<String, IRModule> get_irmodule() { return this->output_.lowered_funcs; }
Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
LoweredOutput output_;
Expand Down
17 changes: 9 additions & 8 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ struct ExecutorCodegen {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}

Map<String, IRModule> GetIRModule() {
return CallFunc<Map<String, IRModule>>("get_irmodule", nullptr);
Map<Target, IRModule> GetIRModule() {
return CallFunc<Map<Target, IRModule>>("get_irmodule", nullptr);
}

runtime::Metadata GetMetadata() { return CallFunc<runtime::Metadata>("get_metadata"); }
Expand Down Expand Up @@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode {
auto lowered_funcs = executor_codegen_->GetIRModule();

// No need to build for external functions.
if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
lowered_funcs.Set("ext_dev", IRModule());
Target ext_dev("ext_dev");
if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
lowered_funcs.Set(ext_dev, IRModule());
}

// Generate a placeholder function that attaches linked params as its arguments.
Expand All @@ -510,11 +511,11 @@ class RelayBuildModule : public runtime::ModuleNode {
DictAttrs attrs{dict};
auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
Map<tir::Var, tir::Buffer>(), attrs);
if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (lowered_funcs.find(target_host) == lowered_funcs.end()) {
lowered_funcs.Set(target_host, IRModule(Map<GlobalVar, BaseFunc>({})));
}
lowered_funcs[target_host->str()]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param),
prim);
}

// When there is no lowered_funcs due to reasons such as optimization.
Expand Down
29 changes: 17 additions & 12 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ namespace {
struct PairHash {
template <typename T1, typename T2>
std::size_t operator()(const std::pair<T1, T2>& k) const {
return std::hash<T1>()(k.first) ^ std::hash<T2>()(k.second);
return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
}
template <typename T2>
std::size_t operator()(const std::pair<Target, T2>& k) const {
return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second));
}
};

Expand Down Expand Up @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
Interpreter(IRModule mod, Map<String, IRModule> per_target_module, Device device, Target target)
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
device_(device),
Expand Down Expand Up @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
*/
PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
Target target) {
std::pair<std::string, std::string> packed_func_key(target->str(), tir_fn_var->name_hint);
std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint);
auto packed_itr = compiled_packed_funcs_.find(packed_func_key);
if (packed_itr != compiled_packed_funcs_.end()) {
// Already compiled.
Expand All @@ -382,8 +386,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

// Project out just the function(s) we need.
IRModule lowered_projected_mod;
auto mod_itr = per_target_module_.find(target->str());
ICHECK(mod_itr != per_target_module_.end())
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> per_target_module_std_map_ =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: don't append a _ for local vars since the convention is it indicates a member var.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

TargetModuleMapToTargetStrModuleMap(per_target_module_);
auto mod_itr = per_target_module_std_map_.find(target);
ICHECK(mod_itr != per_target_module_std_map_.end())
<< "No target module for target '" << target->str() << "'";
const IRModule& target_module = (*mod_itr).second;
for (const auto& var : all_tir_fn_vars) {
Expand All @@ -407,7 +413,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint
<< "' in compiled module for target '" << target->str() << "'";
compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func);
compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func);
}

// Return just what we need for this call.
Expand Down Expand Up @@ -874,11 +880,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// Map from target key to lowered TIR functions derived from mod_.
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
// executed on the default 'cpu' host. Thus this map has at most two entries.
Map<String, IRModule> per_target_module_;
Map<Target, IRModule> per_target_module_;
// Cached packed functions for the primitives and shape functions, keyed by target and
// global var name.
std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash>
compiled_packed_funcs_;
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
// Unique device on which primitives (but not shape functions) will be executed.
// (For simplicity we only run the interpreter on a single device.)
Device device_;
Expand All @@ -895,7 +900,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<String, IRModule>> Prepare(IRModule mod, Device device, Target target) {
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq({transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
Expand Down Expand Up @@ -1014,7 +1019,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
// and can just eval it directly.
expr_to_eval = expr;
}
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_with_expr, device, target);
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
Expand Down Expand Up @@ -1057,7 +1062,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target) {
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_and_global.first, device, target);
Interpreter intrp(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
Expand Down
24 changes: 12 additions & 12 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,32 +85,32 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Map<String, IRModule> GetLoweredFunctions() {
Map<String, IRModule> lowered_functions;
Map<Target, IRModule> GetLoweredFunctions() {
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> lowered_functions;
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}

for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
return lowered_functions;
return TargetStrModuleMapToTargetModuleMap(lowered_functions);
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
Expand Down Expand Up @@ -884,7 +884,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) {

// Annotate the per-target functions with their target and add them to the unified module
for (const auto& kv : mod.per_target_module) {
const String target = kv.first;
const Target target = kv.first;
const IRModule target_module = kv.second;

// Right now, per-target functions are TIR functions, which don't have type definitions, so
Expand Down Expand Up @@ -926,15 +926,15 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
main_mod->AddTypeDef(kv.first, kv.second);
}

Map<String, IRModule> per_target_modules;
Map<Target, IRModule> per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<relay::FunctionNode>()) {
main_mod->Add(var, func);
} else if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";

// Put the function in per_target_modules
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TECompilerNode : public Object {
virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0;

/* Return all functions which have been lowered by the compiler, keyed by target. */
virtual Map<String, IRModule> GetLoweredFunctions() = 0;
virtual Map<Target, IRModule> GetLoweredFunctions() = 0;

/*!
* \brief Just in time compile to get a PackedFunc.
Expand Down Expand Up @@ -144,7 +144,7 @@ struct LoweredModule {
/*! \brief The module which contains the Relay code. */
IRModule main_module;
/*! \brief The module which contains per target code. */
Map<String, IRModule> per_target_module;
Map<Target, IRModule> per_target_module;
/*! \brief The external runtime modules which must be combined with the lowered code. */
Array<tvm::runtime::Module> external_mods;
// TODO(@electriclilies): This might need to become a map
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type);
*/
struct LoweredOutput {
std::string graph_json;
Map<String, IRModule> lowered_funcs;
Map<Target, IRModule> lowered_funcs;
Array<tvm::runtime::Module> external_mods;
Map<String, FunctionInfo> function_metadata;
std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>> params;
Expand Down
19 changes: 19 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,25 @@ std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
return output;
}

// Helper that copies a tvm::Map into a std::unordered_map keyed through
// TargetStrHash / TargetStrEqual. Entries whose keys compare equal under
// TargetStrEqual share one slot; a later entry overwrites an earlier one.
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map) {
  using TargetStrModuleMap = std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>;
  TargetStrModuleMap by_target_str;
  for (const auto& entry : input_map) {
    by_target_str[entry.first] = entry.second;
  }
  return by_target_str;
}

// Helper to convert the std::unordered_map (keyed through TargetStrHash /
// TargetStrEqual) back into a tvm::Map. Every entry of the input is inserted
// into the result with Map::Set.
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
    std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map) {
  Map<Target, IRModule> tvm_map;
  // Iterate by const reference: iterating by value would copy each
  // std::pair<const Target, IRModule> element before inserting it.
  for (const auto& kv : input_map) {
    tvm_map.Set(kv.first, kv.second);
  }
  return tvm_map;
}

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher);
Expand Down