diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 5f2c4de3152a..2b644230dcde 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -377,7 +377,7 @@ class TupleType : public Type { inline Type VoidType() { return TupleType::Empty(); } /*! - * \brief Check whether the tyep represents void. + * \brief Check whether the type represents void. * \return The check result. */ inline bool IsVoidType(const Type& type) { diff --git a/include/tvm/relax/ir_builder.h b/include/tvm/relax/ir_builder.h new file mode 100644 index 000000000000..352882116916 --- /dev/null +++ b/include/tvm/relax/ir_builder.h @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/ir_builder.h + * \brief The utility for constructing Relax AST. + */ +#ifndef TVM_RELAX_IR_BUILDER_H_ +#define TVM_RELAX_IR_BUILDER_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using relay::Call; + +class IRBuilder; + +/*! + * \brief The state of Relax function node being built. + */ +struct RelaxFunction { + /*! \brief The function name. */ + Optional func_name = NullOpt; + /*! \brief The function parameters. */ + Array params; + /*! \brief The bindings in the function. */ + std::vector bindings; + /*! \brief The binding blocks in the function. */ + std::vector binding_blocks; + /*! \brief The return of the function. */ + Expr ret = Tuple(); + /*! \brief The FunctionNode being built. */ + Function func; +}; + +/*! + * \brief A builder that provides APIs to build Relax AST. + */ +class IRBuilderNode : public Object { + public: + /*! + * \brief Fill the function name and parameters. + */ + void FillFuncNameParam(const Array& params, const std::string& func_name); + /*! + * \brief Build a function node. + */ + void BuildFunction(); + /*! + * \brief Build a binding block. + */ + void BuildBlock(); + /*! + * \brief Emit a call node. + * \param call The CallNode to be emitted. + * \return The variable being created and binded to \p call. + */ + Var Emit(const Call& call); + /*! + * \brief Generate an output for the current dataflow block or function. + * \param output The output variable of the block/function. + * \return The variable being binded to \p ouput. + */ + Var EmitOutput(const Expr& output); + /*! + * \brief Get the function being built. + */ + Function Get(); + /*! + * \brief Get binding blocks being built. + */ + std::vector GetBlocks(); + /*! + * \brief Create a IRBuilder. + * \return The created IRBuilder. + */ + TVM_DLL static IRBuilder Create(); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.IRBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, Object); + + private: + /*! \brief The state of the function currently being built. */ + RelaxFunction func; + /*! \brief A flag tracking if currently inside a dataflow block or not. */ + bool is_dataflow = false; + /*! \brief A global variable counter for naming global variables. */ + int global_var_counter = 0; + /*! \brief A dataflow variable counter for naming dataflow variables. */ + int dataflow_var_counter = 0; +}; + +class IRBuilder : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode); +}; + +/*! \brief Auxiliary scope for building Relax function node, + * similar to python's with syntax. + * + * \code + * { + * With scope(ir_builder); + * // build function node. + * } + */ +class FunctionScopeNode : public Object { + public: + IRBuilder ir_builder; + void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.FunctionScope"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionScopeNode, Object); +}; + +class FunctionScope : public ObjectRef { + public: + TVM_DLL FunctionScope(IRBuilder ib); + TVM_DEFINE_OBJECT_REF_METHODS(FunctionScope, ObjectRef, FunctionScopeNode); + class Internal; + + private: + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class With; + // The entry of a function scope. + TVM_DLL void EnterWithScope(); + // The exit of a function scope. + TVM_DLL void ExitWithScope(); +}; + +/*! \brief Auxiliary scope for building Relax dataflow block, + * similar to python's with syntax. + * + * \code + * { + * With scope(ir_builder); + * // build dataflow block. + * } + */ +class DataflowScopeNode : public Object { + public: + IRBuilder ir_builder; + void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.DataflowScope"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowScopeNode, Object); +}; + +class DataflowScope : public ObjectRef { + public: + TVM_DLL DataflowScope(IRBuilder ib); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowScope, ObjectRef, DataflowScopeNode); + class Internal; + + private: + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class With; + // The entry of a dataflow scope. + TVM_DLL void EnterWithScope(); + // The exit of a dataflow scope. + TVM_DLL void ExitWithScope(); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_BUILDER_H_ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index f073a7522855..9967c29a5f5e 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -39,9 +39,7 @@ namespace relax { class ShapeTypeNode : public TypeNode { public: - - void VisitAttrs(tvm::AttrVisitor* v) { - } + void VisitAttrs(tvm::AttrVisitor* v) {} bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return true; @@ -64,10 +62,9 @@ class ShapeType : public Type { const ShapeTypeNode* get() const { return operator->(); } - using ContainerType = ShapeTypeNode; + using ContainerType = ShapeTypeNode; }; - class DynTensorTypeNode : public BaseTensorTypeNode { public: /*! @@ -92,6 +89,10 @@ class DynTensorTypeNode : public BaseTensorTypeNode { hash_reduce(dtype); } + inline bool IsUnknownRank() const { return rank == -1; } + + inline bool IsUnknownDtype() const { return dtype.is_void(); } + static constexpr const char* _type_key = "relax.DynTensorType"; TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode); }; diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index cf2f4c0c751a..b81df5c88331 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -20,6 +20,8 @@ from . import ty from . import vm from . import op +from . import ir_builder +from . import op # Expr @@ -56,3 +58,6 @@ # Operator from .op.base import call_dps + +# IRBuilder +IRBuilder = ir_builder.IRBuilder diff --git a/python/tvm/relax/ir_builder.py b/python/tvm/relax/ir_builder.py new file mode 100644 index 000000000000..426b1ccabe33 --- /dev/null +++ b/python/tvm/relax/ir_builder.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Developer API of constructing Relax AST.""" +from typing import List, Optional, Union, Dict +from tvm.relay.expr import Tuple +from tvm.runtime import Object +from .expr import * +from tvm._ffi.base import _LIB, check_call +from . import _ffi_api + + +@tvm._ffi.register_object("relax.FunctionScope") +class FunctionScope(Object): + """Auxiliary scope for function""" + + def __init__(self, irbuilder): + self.__init_handle_by_constructor__(_ffi_api.CreateFunctionScope, irbuilder) + + def __enter__(self): + return self + + def __exit__(self, ptype, value, trace): + _ffi_api.ExitFunctionScope(self) + + +@tvm._ffi.register_object("relax.DataflowScope") +class DataflowScope(Object): + """Auxiliary scope for Dataflow block""" + + def __init__(self, irbuilder): + self.__init_handle_by_constructor__(_ffi_api.CreateDataflowScope, irbuilder) + + def __enter__(self): + _ffi_api.EnterDataflowScope(self) + + def __exit__(self, ptype, value, trace): + _ffi_api.ExitDataflowScope(self) + + +@tvm._ffi.register_object("relax.IRBuilder") +class IRBuilder(Object): + """A builder to build Relax IR for testing and dev. + + Examples + -------- + .. code-block:: python + + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + with ib.function([x, y], "func"): + with ib.dataflow() as df: + lv0 = ib.emit(rx.add(x, y)) + lv1 = ib.emit(rx.multiply(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_output(gv0) + func = ib.get() + """ + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.IRBuilderCreate) + + def function(self, + params: Optional[Union[Var, Tuple, List[Var]]] = None, + name: Optional[str] = "") -> FunctionScope: + """Annotate a Relax function. + + Parameters + ---------- + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function. + + name : str, optional + The name of the function. If provided, the function is global, otherwise local. + + Returns + ------- + ret: FunctionScope + A FunctionScope for building a Relax function node. + """ + if not params: + params = [] + if not isinstance(params, (list, tuple)): + params = [params] + + _ffi_api.IRBuilderFillFuncNameParam(self, params, name) + return FunctionScope(self) + + def dataflow(self) -> DataflowScope: + """Annotate a Relax dataflow block. + + Returns + ------- + ret: DataflowScope + A DataflowScope for building a Relax dataflow block. + """ + return DataflowScope(self) + + def emit(self, + call: relay.Call) -> Var: + """Emit a call node. + This infers the shape and type of the CallNode, create a variable, + and bind the CallNode to the variable. + + Parameters + ---------- + call : tvm.relay.Call + The call node to be emitted. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets binded to the call code. + """ + return _ffi_api.IRBuilderEmit(self, call) + + def emit_output(self, + output: Union[Expr, Tuple, List[Expr]]) -> None: + """Emit output for the current dataflow block or function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets binded to the output. + """ + if isinstance(output, (list, tuple)): + output = Tuple(output) + return _ffi_api.IRBuilderEmitOutput(self, output) + + def get(self) -> Function: + """Return the function being built. + + Returns + ------- + ret : tvm.relax.Function + A Relax function node being built. + """ + return _ffi_api.IRBuilderGet(self) + + def get_blocks(self) -> List[BindingBlock]: + """Return the binding blocks being built. + + Returns + ------- + ret : List[tvm.relax.BindingBlock] + A list of binding blocks being built. + """ + return _ffi_api.IRBuilderGetBlocks(self) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 32d469f0400b..03f8b8858a4a 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -19,3 +19,4 @@ # Operators from .base import * +from .tensor import * diff --git a/python/tvm/relax/op/tensor.py b/python/tvm/relax/op/tensor.py new file mode 100644 index 000000000000..0e5854bdb586 --- /dev/null +++ b/python/tvm/relax/op/tensor.py @@ -0,0 +1,12 @@ +from . import _ffi_api +from ..expr import Expr + + +def add(lhs: Expr, + rhs: Expr) -> Expr: + return _ffi_api.add(lhs, rhs) + + +def multiply(lhs: Expr, + rhs: Expr) -> Expr: + return _ffi_api.multiply(lhs, rhs) diff --git a/src/relax/expr.cc b/src/relax/expr.cc index 7a4fa8a30826..42f84adf6cda 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -62,6 +62,9 @@ Var::Var(Id vid, n->vid = std::move(vid); n->shape_ = std::move(shape_annotation); n->type_annotation = std::move(type_annotation); + if (n->type_annotation) { + n->checked_type_ = n->type_annotation.value(); + } n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relax/ir_builder.cc b/src/relax/ir_builder.cc new file mode 100644 index 000000000000..e5c09ccebd1d --- /dev/null +++ b/src/relax/ir_builder.cc @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir_builder.cc + */ + +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(IRBuilderNode); +TVM_REGISTER_NODE_TYPE(FunctionScopeNode); +TVM_REGISTER_NODE_TYPE(DataflowScopeNode); + +IRBuilder IRBuilderNode::Create() { + IRBuilder ret(make_object()); + return ret; +} + +void IRBuilderNode::FillFuncNameParam(const Array& params, const std::string& func_name) { + if (!func_name.empty()) { + this->func.func_name = GlobalVar(func_name); + } + + this->func.params = params; +} + +void IRBuilderNode::BuildFunction() { + SeqExpr seq = SeqExpr(this->func.binding_blocks, this->func.ret); + this->func.func = Function(this->func.func_name, this->func.params, seq, {}); + this->global_var_counter = 0; +} + +void IRBuilderNode::BuildBlock() { + if (!this->func.bindings.empty()) { + if (is_dataflow) { + this->func.binding_blocks.emplace_back(DataflowBlock(this->func.bindings)); + } else { + this->func.binding_blocks.emplace_back(BindingBlock(this->func.bindings)); + } + this->func.bindings.clear(); + } + this->dataflow_var_counter = 0; + this->is_dataflow = !this->is_dataflow; +} + +Var IRBuilderNode::Emit(const Call& call) { + Var var; + if (is_dataflow) { + var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter++)), NullOpt, NullOpt); + } else { + var = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt); + } + + this->func.bindings.emplace_back(VarBinding(var, call)); + return var; +} + +Var IRBuilderNode::EmitOutput(const Expr& output) { + Var ret; + if (is_dataflow) { + ret = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt); + ret->shape_ = output->shape_; + ret->checked_type_ = output->checked_type_; + this->func.bindings.emplace_back(VarBinding(ret, output)); + } else { + this->func.ret = output; + } + return ret; +} + +Function IRBuilderNode::Get() { return this->func.func; } + +std::vector IRBuilderNode::GetBlocks() { return this->func.binding_blocks; } + +class FunctionScope::Internal { + public: + static void ExitScope(FunctionScope scope) { scope.ExitWithScope(); } +}; + +FunctionScope::FunctionScope(IRBuilder ib) { + ObjectPtr n = make_object(); + n->ir_builder = std::move(ib); + data_ = std::move(n); +} + +void FunctionScope::ExitWithScope() { + this->get()->ir_builder->BuildBlock(); + this->get()->ir_builder->BuildFunction(); +} + +class DataflowScope::Internal { + public: + static void EnterScope(DataflowScope scope) { scope.EnterWithScope(); } + + static void ExitScope(DataflowScope scope) { scope.ExitWithScope(); } +}; + +DataflowScope::DataflowScope(IRBuilder ib) { + ObjectPtr n = make_object(); + n->ir_builder = std::move(ib); + data_ = std::move(n); +} + +void DataflowScope::EnterWithScope() { + this->get()->ir_builder->BuildBlock(); +} + +void DataflowScope::ExitWithScope() { + this->get()->ir_builder->BuildBlock(); +} + +TVM_REGISTER_GLOBAL("relax.IRBuilderCreate").set_body_typed(IRBuilderNode::Create); + +TVM_REGISTER_GLOBAL("relax.IRBuilderFillFuncNameParam") +.set_body_typed([](IRBuilder builder, const Array& params, const std::string& func_name) { + return builder->FillFuncNameParam(params, func_name); +}); + +TVM_REGISTER_GLOBAL("relax.IRBuilderBuildFunction").set_body_typed([](IRBuilder builder) { + return builder->BuildFunction(); +}); + +TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder, const Call& call) { + return builder->Emit(call); +}); + +TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput") +.set_body_typed([](IRBuilder builder, const Expr& output) { + return builder->EmitOutput(output); +}); + +TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) { + return builder->Get(); +}); + +TVM_REGISTER_GLOBAL("relax.IRBuilderGetBlocks").set_body_typed([](IRBuilder builder) { + return Array(builder->GetBlocks()); +}); + +TVM_REGISTER_GLOBAL("relax.CreateFunctionScope").set_body_typed([](IRBuilder ib) { + return FunctionScope(ib); +}); + +TVM_REGISTER_GLOBAL("relax.ExitFunctionScope").set_body_typed(FunctionScope::Internal::ExitScope); + +TVM_REGISTER_GLOBAL("relax.CreateDataflowScope").set_body_typed([](IRBuilder ib) { + return DataflowScope(ib); +}); + +TVM_REGISTER_GLOBAL("relax.EnterDataflowScope").set_body_typed(DataflowScope::Internal::EnterScope); + +TVM_REGISTER_GLOBAL("relax.ExitDataflowScope").set_body_typed(DataflowScope::Internal::ExitScope); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op.cc b/src/relax/op/op.cc similarity index 100% rename from src/relax/op.cc rename to src/relax/op/op.cc diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h new file mode 100644 index 000000000000..4e1ce826f9b2 --- /dev/null +++ b/src/relax/op/op_common.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_common.h + * \brief A set of utilities and common functionality + * for Relax ops. + */ +#ifndef TVM_RELAX_OP_OP_COMMON_H_ +#define TVM_RELAX_OP_OP_COMMON_H_ + +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! Quick helper macro + * - Expose a positional make function to construct the node. + * - Register op to the registry. + * + * We make the decision to always only expose positional argument. + * We will do rewrapping in the frontend to support language + * sugars such as keyword arguments and default value. + * + * \param OpName the name of registry. + */ +#define RELAX_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relax.op." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP("relax." OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_OP_COMMON_H_ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc new file mode 100644 index 000000000000..df053027aff2 --- /dev/null +++ b/src/relax/op/tensor/binary.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.cc + * \brief binary broadcast operators. + */ + +#include +#include +#include +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +using Expr = tvm::RelayExpr; +using relay::Call; + +#define RELAX_BINARY_COMPUTE(FTOPI) \ + [](const Attrs& attrs, const Array& inputs, \ + const Type& out_type) -> Array { \ + ICHECK_EQ(inputs.size(), 2U); \ + return {FTOPI(inputs[0], inputs[1])}; \ + } + +RELAX_REGISTER_BINARY_OP("add") + .describe("Elementwise add with broadcasting") + .set_support_level(1); + +RELAX_REGISTER_BINARY_OP("multiply") + .describe("Elementwise multiply with broadcasting") + .set_support_level(1); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index b7655ba32b9e..5553a08108b9 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -69,11 +69,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_shape") TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") .set_body_typed([](void* vm_state_ptr, Index size, Index alignment, Index device_type, - DLDataType dtype_hint) { + DLDataType dtype_hint) { VMState* vm_state = static_cast(vm_state_ptr); DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment - << ", dtype_hint=" << runtime::DLDataType2String(dtype_hint) - << ", device_type=" << device_type; + << ", dtype_hint=" << runtime::DLDataType2String(dtype_hint) + << ", device_type=" << device_type; auto storage_obj = runtime::SimpleObjAllocator().make_object(); ICHECK_LT(static_cast(device_type), vm_state->allocators.size()) diff --git a/tests/python/relax/test_irbuilder.py b/tests/python/relax/test_irbuilder.py new file mode 100644 index 000000000000..28bbc9756b67 --- /dev/null +++ b/tests/python/relax/test_irbuilder.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm import tir +from tvm import relax as rx + + +def test_dataflow_block(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv0" + lv1 = ib.emit(rx.op.multiply(lv0, y)) + assert lv1.name_hint == "lv1" + gv0 = ib.emit_output(lv1) + assert gv0.name_hint == "gv0" + blocks = ib.get_blocks() + assert len(blocks) == 1 + assert len(blocks[-1].bindings) == 3 + + +def test_function_single_block(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + with ib.function([x, y]): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv0" + lv1 = ib.emit(rx.op.multiply(lv0, y)) + assert lv1.name_hint == "lv1" + gv0 = ib.emit_output(lv1) + assert gv0.name_hint == "gv0" + ib.emit_output(gv0) + func = ib.get() + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv0 + assert len(func.body.blocks) == 1 + assert len(func.body.blocks[0].bindings) == 3 + + +def test_function_multi_blocks(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + with ib.function([x, y], "func"): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv0" + gv0 = ib.emit_output(lv0) + assert gv0.name_hint == "gv0" + gv1 = ib.emit(rx.op.add(gv0, gv0)) + assert gv1.name_hint == "gv1" + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.add(gv1, gv1)) + assert lv0.name_hint == "lv0" + gv2 = ib.emit_output(gv1) + ib.emit_output(gv2) + func = ib.get() + assert func.params[0] == x + assert func.params[1] == y + assert func.name.name_hint == "func" + assert func.body.body == gv2 + assert len(func.body.blocks) == 3 + assert len(func.body.blocks[0].bindings) == 2 + assert len(func.body.blocks[1].bindings) == 1 + assert len(func.body.blocks[2].bindings) == 2 + + +if __name__ == "__main__": + test_dataflow_block() + test_function_single_block() + test_function_multi_blocks() diff --git a/tests/python/relax/test_op.py b/tests/python/relax/test_op.py index 8ef5abe5b4ac..9f708b89a42b 100644 --- a/tests/python/relax/test_op.py +++ b/tests/python/relax/test_op.py @@ -18,8 +18,6 @@ from tvm import tir from tvm import relax as rx from tvm.script import ty -from tvm.ir import TensorType -import numpy as np @tvm.register_func("test.op.identity") def identity_packed(a):