Skip to content

Commit

Permalink
Change AOT from ExprVisitor to MixedModeVisitor (apache#8856)
Browse files Browse the repository at this point in the history
This should allow better scale-ability for AOT when targeting larger networks.
  • Loading branch information
Mousius authored and ylc committed Sep 29, 2021
1 parent 8db94c4 commit f7a41dd
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ using StorageMap =
* This is an on demand allocator for AOT. A new temporary
* (storage allocator identifier) is allocated for each operation.
*/
class AOTOnDemandAllocator : public ExprVisitor {
class AOTOnDemandAllocator : public MixedModeVisitor {
public:
// run the visitor on a function.
void Run(const Function& func) {
Expand Down Expand Up @@ -84,10 +84,7 @@ class AOTOnDemandAllocator : public ExprVisitor {
AssignReturnSid(GetRef<Expr>(op));
}

void VisitExpr_(const VarNode* op) final {
ExprVisitor::VisitExpr_(op);
AssignReturnSid(GetRef<Expr>(op));
}
void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef<Expr>(op)); }

void VisitExpr_(const FunctionNode* op) final {
// do not recurse into sub function.
Expand Down Expand Up @@ -218,7 +215,7 @@ class AOTOnDemandAllocator : public ExprVisitor {
};

/*! \brief Code generator for AOT executor */
class AOTExecutorCodegen : public ExprVisitor {
class AOTExecutorCodegen : public MixedModeVisitor {
protected:
/*!
* \brief Utility function to allocate a DLTensor or TVMValue
Expand Down Expand Up @@ -437,7 +434,6 @@ class AOTExecutorCodegen : public ExprVisitor {
void VisitExpr_(const OpNode* op) override {
throw std::runtime_error("can not compile op in non-eta expanded form");
}
void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); }
void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); }
void VisitExpr_(const FunctionNode* op) override {
ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
Expand Down

0 comments on commit f7a41dd

Please sign in to comment.