diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index 48046f0e9175..8c9b7de76d80 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -40,3 +40,20 @@ prim_func, ) from .var import Buffer +from .stmt import ( + Assert, + let, + allocate, + allocate_const, + launch_thread, + realize, + attr, + while_, + if_, + then_, + else_, + env_thread, + buffer_store, + prefetch, + evaluate, +) diff --git a/python/tvm/script/builder/tir/stmt.py b/python/tvm/script/builder/tir/stmt.py new file mode 100644 index 000000000000..5d3ca58b7b9f --- /dev/null +++ b/python/tvm/script/builder/tir/stmt.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR For Frame""" +import numpy as np +from typing import List, Union + +from tvm._ffi import register_object as _register_object +from tvm.tir import Buffer, IterVar, PrimExpr, Var, BufferRegion, Stmt, StringImm +from tvm.ir import Type, Range +from tvm.runtime import ndarray as nd, Object + +from . import _ffi_api +from .. import _ffi_api as _base_ffi_api +from .base import TIRFrame + + +@_register_object("script.builder.tir.AssertFrame") +class AssertFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.LetFrame") +class LetFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.AllocateFrame") +class AllocateFrame(TIRFrame): + def __enter__(self) -> Buffer: + _base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore + return self.buffer + + +@_register_object("script.builder.tir.AllocateConstFrame") +class AllocateConstFrame(TIRFrame): + def __enter__(self) -> Buffer: + _base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore + return self.buffer + + +@_register_object("script.builder.tir.LaunchThreadFrame") +class LaunchThreadFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.RealizeFrame") +class RealizeFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.AttrFrame") +class AttrFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.WhileFrame") +class WhileFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.IfFrame") +class IfFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.ThenFrame") +class ThenFrame(TIRFrame): + ... + + +@_register_object("script.builder.tir.ElseFrame") +class ElseFrame(TIRFrame): + ... + + +def Assert(condition: PrimExpr, message: str) -> AssertFrame: + return _ffi_api.AssertFrame(condition, message) # pylint: disable=no-member # type: ignore + + +def let(var: Var, value: PrimExpr) -> LetFrame: + return _ffi_api.LetFrame(var, value) # pylint: disable=no-member # type: ignore + + +def allocate( + extents: List[PrimExpr], + dtype: str, + storage_scope: str = "", + condition: PrimExpr = None, + annotations=None, +) -> AllocateFrame: + return _ffi_api.AllocateFrame( + extents, dtype, storage_scope, condition, annotations + ) # pylint: disable=no-member # type: ignore + + +def allocate_const(data: List[PrimExpr], dtype: str, extents: List[PrimExpr]) -> AllocateConstFrame: + return _ffi_api.AllocateConstFrame( + nd.array(np.asarray(data, dtype)), dtype, extents + ) # pylint: disable=no-member # type: ignore + + +def launch_thread(iter_var: IterVar, extent: PrimExpr) -> LaunchThreadFrame: + return _ffi_api.LaunchThreadFrame(iter_var, extent) # pylint: disable=no-member # type: ignore + + +def realize( + buffer_slice: BufferRegion, storage_scope: str, condition: PrimExpr = True +) -> RealizeFrame: + return _ffi_api.RealizeFrame( + buffer_slice, storage_scope, condition + ) # pylint: disable=no-member # type: ignore + + +def attr(node: Object, attr_key: str, value: Union[PrimExpr, str]) -> AttrFrame: + if isinstance(value, str): + value = StringImm(value) + return _ffi_api.AttrFrame(node, attr_key, value) # pylint: disable=no-member # type: ignore + + +def while_(condition: PrimExpr) -> WhileFrame: + return _ffi_api.WhileFrame(condition) # pylint: disable=no-member # type: ignore + + +def if_(condition: PrimExpr) -> IfFrame: + return _ffi_api.IfFrame(condition) # pylint: disable=no-member # type: ignore + + +def then_() -> ThenFrame: + return _ffi_api.ThenFrame() # pylint: disable=no-member # type: ignore + + +def else_() -> ElseFrame: + return _ffi_api.ElseFrame() # pylint: disable=no-member # type: ignore + + +def env_thread(thread_tag: str) -> IterVar: + return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore + + +def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[PrimExpr]) -> None: + return _ffi_api.BufferStore(buffer, value, indices) # pylint: disable=no-member # type: ignore + + +def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None: + return _ffi_api.Prefetch(buffer, indices) # pylint: disable=no-member # type: ignore + + +def evaluate(value: PrimExpr) -> None: + return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index 663017624e09..6566809d6eba 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -52,7 +52,8 @@ void BlockFrameNode::ExitWithScope() { Block block = Block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers, match_buffers, annotations); if (no_realize) { - CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`"; + CHECK(iter_values.empty()) + << "ValueError: Block bindings are not allowed when `no_realize=True`"; CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`"; AddToParent(block); } else { @@ -68,7 +69,7 @@ BlockInitFrame Init() { void BlockInitFrameNode::EnterWithScope() { BlockFrame frame = FindBlockFrame("T.init"); if (frame->init.defined()) { - LOG(FATAL) << "Duplicate block init declaration"; + LOG(FATAL) << "ValueError: Duplicate block init declaration"; } TIRFrameNode::EnterWithScope(); } @@ -92,7 +93,7 @@ BlockFrame FindBlockFrame(const String& method) { void Where(PrimExpr predicate) { BlockFrame frame = FindBlockFrame("T.where"); if (frame->predicate.defined()) { - LOG(FATAL) << "Duplicate block predicate declaration, previous one is " + LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is " << frame->predicate.value(); } frame->predicate = predicate; @@ -102,7 +103,7 @@ void Reads(Array buffer_slices) { using namespace tvm::tir; BlockFrame frame = FindBlockFrame("T.reads"); if (!frame->reads.empty()) { - LOG(FATAL) << "Duplicate read region declaration, previous one is " << frame->reads; + LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; } for (const ObjectRef& obj : buffer_slices) { if (const auto* buffer_region = obj.as()) { @@ -119,7 +120,8 @@ void Writes(Array buffer_slices) { using namespace tvm::tir; BlockFrame frame = FindBlockFrame("T.writes"); if (!frame->writes.empty()) { - LOG(FATAL) << "Duplicate write region declaration, previous one is " << frame->writes; + LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " + << frame->writes; } for (const ObjectRef& obj : buffer_slices) { if (const auto* buffer_region = obj.as()) { @@ -135,7 +137,7 @@ void Writes(Array buffer_slices) { void BlockAttrs(Map attrs) { BlockFrame frame = FindBlockFrame("T.block_attr"); if (!frame->annotations.empty()) { - LOG(FATAL) << "Duplicate block annotations, previous one is " << frame->annotations; + LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations; } frame->annotations = attrs; } diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 5d1d2ae9defb..beebf23f6a98 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -116,7 +116,7 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { void FuncName(String name) { PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); if (frame->name.defined()) { - LOG(FATAL) << "Duplicate prim func name, previous one is " << frame->name.value(); + LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); } frame->name = name; } @@ -125,7 +125,7 @@ void FuncAttrs(Map attrs) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); if (!frame->attrs.empty()) { - LOG(FATAL) << "Duplicate prim func annotations, previous one is " << frame->attrs; + LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; } frame->attrs = attrs; } @@ -133,7 +133,8 @@ void FuncAttrs(Map attrs) { tvm::Type FuncRet(tvm::Type ret_type) { PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type"); if (frame->ret_type.defined()) { - LOG(FATAL) << "Duplicate prim func return type, previous one is " << frame->ret_type.value(); + LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is " + << frame->ret_type.value(); } frame->ret_type = ret_type; return ret_type; diff --git a/src/script/builder/tir/stmt.cc b/src/script/builder/tir/stmt.cc new file mode 100644 index 000000000000..60e788109e46 --- /dev/null +++ b/src/script/builder/tir/stmt.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./stmt.h" + +#include +#include +#include + +#include "./prim_func_frame.h" +#include "./var.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +void AssertFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts))); +} + +void LetFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts))); +} + +void AllocateFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + tvm::tir::Buffer flattened_buffer = buffer.GetFlattenedBuffer(); + AddToParent(tvm::tir::Allocate(buffer->data, flattened_buffer->dtype, flattened_buffer->shape, + condition, AsStmt(stmts), annotations)); +} + +void AllocateConstFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts))); +} + +void LaunchThreadFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); +} + +void RealizeFrameNode::ExitWithScope() { + using namespace tvm::tir; + TIRFrameNode::ExitWithScope(); + AddToParent(AttrStmt( + buffer_slice->buffer, "realize_scope", StringImm(storage_scope), + BufferRealize(buffer_slice->buffer, buffer_slice->region, condition, AsStmt(stmts)))); +} + +void AttrFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts))); +} + +void WhileFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::While(condition, AsStmt(stmts))); +} + +void IfFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + if (!stmts.empty()) { + LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame"; + } + if (!then_stmts.defined()) { + LOG(FATAL) << "IfThenElse frame should have at least one then branch"; + } + AddToParent(tvm::tir::IfThenElse( + condition, AsStmt(then_stmts.value()), + else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr))); +} + +IfFrame FindIfFrame(const String& method) { + if (Optional if_frame = Builder::Current()->GetLastFrame()) { + return if_frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under T.if_()"; + } + throw; +} + +void ThenFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("T.then_"); + if (frame->then_stmts.defined()) { + LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is " + << frame->then_stmts.value(); + } + TIRFrameNode::EnterWithScope(); +} + +void ThenFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + FindIfFrame("T.then_")->then_stmts = stmts; +} + +void ElseFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("T.else_"); + if (!frame->then_stmts.defined()) { + LOG(FATAL) << "The else branch should follow then branch"; + } + if (frame->else_stmts.defined()) { + LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is " + << frame->else_stmts.value(); + } + TIRFrameNode::EnterWithScope(); +} + +void ElseFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + FindIfFrame("T.else_")->else_stmts = stmts; +} + +AssertFrame Assert(PrimExpr condition, String message) { + ObjectPtr n = make_object(); + n->condition = condition; + n->message = tvm::tir::StringImm(message); + return AssertFrame(n); +} + +LetFrame Let(tvm::tir::Var var, PrimExpr value) { + ObjectPtr n = make_object(); + n->var = var; + n->value = value; + return LetFrame(n); +} + +AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope, + Optional condition, Optional> annotations) { + ObjectPtr n = make_object(); + n->extents = extents; + n->dtype = dtype; + n->storage_scope = storage_scope; + n->condition = condition.value_or(true); + if (!n->condition->dtype.is_bool()) { + n->condition = tvm::cast(DataType::Bool(), n->condition); + } + n->annotations = annotations.value_or(Map()); + n->buffer = DeclBuffer(extents, dtype, "", NullOpt, {}, PrimExpr(), storage_scope, 0, 0, + "default", {}, Span()); + return AllocateFrame(n); +} + +AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, + Array extents) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->extents = extents; + n->data = data; + n->buffer = + DeclBuffer(extents, dtype, "", NullOpt, {}, PrimExpr(), "", 0, 0, "default", {}, Span()); + return AllocateConstFrame(n); +} + +LaunchThreadFrame LaunchThread(tvm::tir::IterVar iter_var, PrimExpr extent) { + ObjectPtr n = make_object(); + if (!iter_var->dom.defined()) { + const_cast(iter_var.get())->dom = Range(0, extent); + } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { + LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. " + << iter_var->dom->extent << " vs " << extent; + } + n->iter_var = iter_var; + n->extent = extent; + n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent"; + return LaunchThreadFrame(n); +} + +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, + PrimExpr condition) { + ObjectPtr n = make_object(); + n->buffer_slice = buffer_slice; + n->storage_scope = storage_scope; + n->condition = condition; + return RealizeFrame(n); +} + +AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) { + ObjectPtr n = make_object(); + n->node = node; + n->attr_key = attr_key; + n->value = value; + return AttrFrame(n); +} + +WhileFrame While(PrimExpr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + return WhileFrame(n); +} + +IfFrame If(PrimExpr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_stmts = NullOpt; + n->else_stmts = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +tvm::tir::IterVar EnvThread(String thread_tag) { + using namespace tvm::tir; + return IterVar(Range{nullptr}, Var("", DataType::Int(32)), IterVarType::kThreadIndex, thread_tag); +} + +void BufferStore(tvm::tir::Buffer buffer, PrimExpr value, Array indices) { + AddToParent(tvm::tir::BufferStore(buffer, value, indices)); +} + +void Prefetch(tvm::tir::Buffer buffer, Array bounds) { + AddToParent(tvm::tir::Prefetch(buffer, bounds)); +} + +void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } + +TVM_REGISTER_NODE_TYPE(AssertFrameNode); +TVM_REGISTER_NODE_TYPE(LetFrameNode); +TVM_REGISTER_NODE_TYPE(AllocateFrameNode); +TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode); +TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode); +TVM_REGISTER_NODE_TYPE(RealizeFrameNode); +TVM_REGISTER_NODE_TYPE(AttrFrameNode); +TVM_REGISTER_NODE_TYPE(WhileFrameNode); +TVM_REGISTER_NODE_TYPE(IfFrameNode); +TVM_REGISTER_NODE_TYPE(ThenFrameNode); +TVM_REGISTER_NODE_TYPE(ElseFrameNode); +TVM_REGISTER_GLOBAL("script.builder.tir.AssertFrame").set_body_typed(Assert); +TVM_REGISTER_GLOBAL("script.builder.tir.LetFrame").set_body_typed(Let); +TVM_REGISTER_GLOBAL("script.builder.tir.AllocateFrame").set_body_typed(Allocate); +TVM_REGISTER_GLOBAL("script.builder.tir.AllocateConstFrame").set_body_typed(AllocateConst); +TVM_REGISTER_GLOBAL("script.builder.tir.RealizeFrame").set_body_typed(Realize); +TVM_REGISTER_GLOBAL("script.builder.tir.AttrFrame").set_body_typed(Attr); +TVM_REGISTER_GLOBAL("script.builder.tir.WhileFrame").set_body_typed(While); +TVM_REGISTER_GLOBAL("script.builder.tir.IfFrame").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.builder.tir.ThenFrame").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.builder.tir.ElseFrame").set_body_typed(Else); +TVM_REGISTER_GLOBAL("script.builder.tir.LaunchThreadFrame").set_body_typed(LaunchThread); +TVM_REGISTER_GLOBAL("script.builder.tir.EnvThread").set_body_typed(EnvThread); +TVM_REGISTER_GLOBAL("script.builder.tir.BufferStore").set_body_typed(BufferStore); +TVM_REGISTER_GLOBAL("script.builder.tir.Prefetch").set_body_typed(Prefetch); +TVM_REGISTER_GLOBAL("script.builder.tir.Evaluate").set_body_typed(Evaluate); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/stmt.h b/src/script/builder/tir/stmt.h new file mode 100644 index 000000000000..d378431c42fb --- /dev/null +++ b/src/script/builder/tir/stmt.h @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_STMT_H_ +#define TVM_SCRIPT_BUILDER_TIR_STMT_H_ + +#include "./base.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +class AssertFrameNode : public TIRFrameNode { + public: + PrimExpr condition; + PrimExpr message; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("message", &message); + } + + static constexpr const char* _type_key = "script.builder.tir.AssertFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AssertFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); +}; + +class LetFrameNode : public TIRFrameNode { + public: + tvm::tir::Var var; + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("var", &var); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.builder.tir.LetFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class LetFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); +}; + +class AllocateFrameNode : public TIRFrameNode { + public: + Array extents; + DataType dtype; + String storage_scope; + PrimExpr condition; + Map annotations; + tvm::tir::Buffer buffer; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("extents", &extents); + v->Visit("dtype", &dtype); + v->Visit("storage_scope", &storage_scope); + v->Visit("condition", &condition); + v->Visit("annotations", &annotations); + v->Visit("buffer", &buffer); + } + + static constexpr const char* _type_key = "script.builder.tir.AllocateFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AllocateFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); +}; + +class AllocateConstFrameNode : public TIRFrameNode { + public: + DataType dtype; + Array extents; + tvm::runtime::NDArray data; + tvm::tir::Buffer buffer; + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("dtype", &dtype); + v->Visit("extents", &extents); + v->Visit("data", &data); + v->Visit("buffer", &buffer); + } + + static constexpr const char* _type_key = "script.builder.tir.AllocateConstFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AllocateConstFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, + AllocateConstFrameNode); +}; + +class LaunchThreadFrameNode : public TIRFrameNode { + public: + PrimExpr extent; + String attr_key; + tvm::tir::IterVar iter_var; + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("extent", &extent); + v->Visit("attr_key", &attr_key); + v->Visit("iter_var", &iter_var); + } + + static constexpr const char* _type_key = "script.builder.tir.LaunchThreadFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class LaunchThreadFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, + LaunchThreadFrameNode); +}; + +class RealizeFrameNode : public TIRFrameNode { + public: + tvm::tir::BufferRegion buffer_slice; + String storage_scope; + PrimExpr condition; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("buffer_slice", &buffer_slice); + v->Visit("storage_scope", &storage_scope); + v->Visit("condition", &condition); + } + + static constexpr const char* _type_key = "script.builder.tir.RealizeFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class RealizeFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); +}; + +class AttrFrameNode : public TIRFrameNode { + public: + ObjectRef node; + String attr_key; + PrimExpr value; + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("node", &node); + v->Visit("attr_key", &attr_key); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.builder.tir.AttrFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AttrFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); +}; + +class WhileFrameNode : public TIRFrameNode { + public: + PrimExpr condition; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + } + + static constexpr const char* _type_key = "script.builder.tir.WhileFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class WhileFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); +}; + +class IfFrameNode : public TIRFrameNode { + public: + PrimExpr condition; + Optional> then_stmts; + Optional> else_stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_stmts", &then_stmts); + v->Visit("else_stmts", &else_stmts); + } + + static constexpr const char* _type_key = "script.builder.tir.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class IfFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); +}; + +class ThenFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.builder.tir.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class ThenFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); +}; + +class ElseFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.builder.tir.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class ElseFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); +}; + +tvm::tir::IterVar EnvThread(String thread_tag); +void BufferStore(tvm::tir::Buffer buffer, PrimExpr value, Array indices); +void Prefetch(tvm::tir::Buffer buffer, Array bounds); +void Evaluate(PrimExpr value); + +AssertFrame Assert(PrimExpr condition, String message); +LetFrame Let(tvm::tir::Var var, PrimExpr value); +AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", + Optional condition = NullOpt, + Optional> annotations = NullOpt); +AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, + Array extents); +LaunchThreadFrame LaunchThread(tvm::tir::IterVar iter_var, PrimExpr extent); +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); +AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value); +WhileFrame While(PrimExpr condition); +IfFrame If(PrimExpr condition); +ThenFrame Then(); +ElseFrame Else(); +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_STMT_H_ diff --git a/tests/python/tvmscript/test_builder_basic.py b/tests/python/tvmscript/test_builder_basic.py index 035c5034b0ca..2fd63967ea09 100644 --- a/tests/python/tvmscript/test_builder_basic.py +++ b/tests/python/tvmscript/test_builder_basic.py @@ -173,9 +173,65 @@ def test_builder_for(): print(b.get().script()) +def test_builder_stmt(): + print("test_builder_stmt") + with Builder() as b: + with T.prim_func(): + thread_x = def_("thread_x", T.env_thread("threadIdx.x")) + thread_y = def_("thread_y", T.env_thread("threadIdx.y")) + buffer_x = def_("buffer_x", tvm.tir.decl_buffer([128, 128])) + buffer_y = def_("buffer_y", tvm.tir.decl_buffer([128, 128])) + var_x = def_("var_x", tvm.tir.Var("", dtype="int32")) + var_y = def_("var_y", tvm.tir.Var("", dtype="int32")) + with T.Assert(var_x < var_y, ""): + with T.Assert(1, "true"): + pass + with T.let(var_x, var_y): + pass + with T.allocate([128], "uint8", "global") as alloc_x: + with T.allocate([128], "uint8", "global") as alloc_y: + alloc_x, alloc_y = def_many(["alloc_x", "alloc_y"], [alloc_x, alloc_y]) + with T.allocate_const([1, 1, 1, 1, 1], "int32", [5]) as alloc_const_x: + with T.allocate_const([10, 10, 10], "float32", [3]) as alloc_const_y: + alloc_const_x, alloc_const_y = def_many( + ["alloc_const_x", "alloc_const_y"], [alloc_const_x, alloc_const_y] + ) + with T.realize(BufferRegion(buffer_x, [Range(0, var_x), Range(0, var_y)]), ""): + with T.realize(BufferRegion(buffer_y, [Range(var_x, 128), Range(var_y, 128)]), ""): + pass + with T.attr(buffer_x, "key_x", "value_x"): + with T.attr(buffer_y, "key_y", "value_y"): + pass + with T.launch_thread(thread_x, 4): + with T.launch_thread(thread_y, 4): + pass + with T.while_(var_x < var_y): + with T.while_(var_x > 0): + pass + with T.if_(var_x < var_y): + with T.then_(): + T.evaluate(0) + T.evaluate(1) + with T.else_(): + T.evaluate(0) + T.evaluate(1) + with T.if_(1): + with T.then_(): + T.evaluate(1) + T.prefetch(buffer_x, [Range(0, 64), Range(64, 128)]) + T.prefetch(buffer_y, [Range(0, var_x), Range(var_y, 128)]) + T.buffer_store(buffer_x, 1, [0, 0]) + T.buffer_store(buffer_x, var_x + var_y, [var_x, var_y]) + T.evaluate(var_x + var_y) + T.evaluate(1) + + print(b.get().script()) + + if __name__ == "__main__": test_builder_root_block() test_builder_axis() test_builder_prim_func() test_builder_block() test_builder_for() + test_builder_stmt()