Skip to content

Commit

Permalink
stmt methods (apache#47)
Browse files Browse the repository at this point in the history
* `stmt` methods 0

* `stmt` methods 1

* `stmt` methods 2

* `stmt` methods 3

* `stmt` methods 4

* add `T.while` method

* `stmt` methods without `with`

* `IfFrame`, `ThenFrame`, `ElseFrame` as replacement

* apply code review suggestions 0

* apply code review suggestions 1

* apply code review suggestions 2

* apply code review suggestions
  • Loading branch information
cyx-6 authored and junrushao committed Jul 13, 2022
1 parent 0cbb858 commit f480eb2
Show file tree
Hide file tree
Showing 7 changed files with 828 additions and 9 deletions.
17 changes: 17 additions & 0 deletions python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,20 @@
prim_func,
)
from .var import Buffer
from .stmt import (
Assert,
let,
allocate,
allocate_const,
launch_thread,
realize,
attr,
while_,
if_,
then_,
else_,
env_thread,
buffer_store,
prefetch,
evaluate,
)
163 changes: 163 additions & 0 deletions python/tvm/script/builder/tir/stmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR For Frame"""
import numpy as np
from typing import List, Union

from tvm._ffi import register_object as _register_object
from tvm.tir import Buffer, IterVar, PrimExpr, Var, BufferRegion, Stmt, StringImm
from tvm.ir import Type, Range
from tvm.runtime import ndarray as nd, Object

from . import _ffi_api
from .. import _ffi_api as _base_ffi_api
from .base import TIRFrame


@_register_object("script.builder.tir.AssertFrame")
class AssertFrame(TIRFrame):
...


@_register_object("script.builder.tir.LetFrame")
class LetFrame(TIRFrame):
...


@_register_object("script.builder.tir.AllocateFrame")
class AllocateFrame(TIRFrame):
def __enter__(self) -> Buffer:
_base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
return self.buffer


@_register_object("script.builder.tir.AllocateConstFrame")
class AllocateConstFrame(TIRFrame):
def __enter__(self) -> Buffer:
_base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
return self.buffer


@_register_object("script.builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
...


@_register_object("script.builder.tir.RealizeFrame")
class RealizeFrame(TIRFrame):
...


@_register_object("script.builder.tir.AttrFrame")
class AttrFrame(TIRFrame):
...


@_register_object("script.builder.tir.WhileFrame")
class WhileFrame(TIRFrame):
...


@_register_object("script.builder.tir.IfFrame")
class IfFrame(TIRFrame):
...


@_register_object("script.builder.tir.ThenFrame")
class ThenFrame(TIRFrame):
...


@_register_object("script.builder.tir.ElseFrame")
class ElseFrame(TIRFrame):
...


def Assert(condition: PrimExpr, message: str) -> AssertFrame:
return _ffi_api.AssertFrame(condition, message) # pylint: disable=no-member # type: ignore


def let(var: Var, value: PrimExpr) -> LetFrame:
return _ffi_api.LetFrame(var, value) # pylint: disable=no-member # type: ignore


def allocate(
extents: List[PrimExpr],
dtype: str,
storage_scope: str = "",
condition: PrimExpr = None,
annotations=None,
) -> AllocateFrame:
return _ffi_api.AllocateFrame(
extents, dtype, storage_scope, condition, annotations
) # pylint: disable=no-member # type: ignore


def allocate_const(data: List[PrimExpr], dtype: str, extents: List[PrimExpr]) -> AllocateConstFrame:
return _ffi_api.AllocateConstFrame(
nd.array(np.asarray(data, dtype)), dtype, extents
) # pylint: disable=no-member # type: ignore


def launch_thread(iter_var: IterVar, extent: PrimExpr) -> LaunchThreadFrame:
return _ffi_api.LaunchThreadFrame(iter_var, extent) # pylint: disable=no-member # type: ignore


def realize(
buffer_slice: BufferRegion, storage_scope: str, condition: PrimExpr = True
) -> RealizeFrame:
return _ffi_api.RealizeFrame(
buffer_slice, storage_scope, condition
) # pylint: disable=no-member # type: ignore


def attr(node: Object, attr_key: str, value: Union[PrimExpr, str]) -> AttrFrame:
if isinstance(value, str):
value = StringImm(value)
return _ffi_api.AttrFrame(node, attr_key, value) # pylint: disable=no-member # type: ignore


def while_(condition: PrimExpr) -> WhileFrame:
return _ffi_api.WhileFrame(condition) # pylint: disable=no-member # type: ignore


def if_(condition: PrimExpr) -> IfFrame:
return _ffi_api.IfFrame(condition) # pylint: disable=no-member # type: ignore


def then_() -> ThenFrame:
return _ffi_api.ThenFrame() # pylint: disable=no-member # type: ignore


def else_() -> ElseFrame:
return _ffi_api.ElseFrame() # pylint: disable=no-member # type: ignore


def env_thread(thread_tag: str) -> IterVar:
return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore


def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[PrimExpr]) -> None:
return _ffi_api.BufferStore(buffer, value, indices) # pylint: disable=no-member # type: ignore


def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
return _ffi_api.Prefetch(buffer, indices) # pylint: disable=no-member # type: ignore


def evaluate(value: PrimExpr) -> None:
return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore
14 changes: 8 additions & 6 deletions src/script/builder/tir/block_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ void BlockFrameNode::ExitWithScope() {
Block block = Block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers,
match_buffers, annotations);
if (no_realize) {
CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`";
CHECK(iter_values.empty())
<< "ValueError: Block bindings are not allowed when `no_realize=True`";
CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
AddToParent(block);
} else {
Expand All @@ -68,7 +69,7 @@ BlockInitFrame Init() {
void BlockInitFrameNode::EnterWithScope() {
BlockFrame frame = FindBlockFrame("T.init");
if (frame->init.defined()) {
LOG(FATAL) << "Duplicate block init declaration";
LOG(FATAL) << "ValueError: Duplicate block init declaration";
}
TIRFrameNode::EnterWithScope();
}
Expand All @@ -92,7 +93,7 @@ BlockFrame FindBlockFrame(const String& method) {
void Where(PrimExpr predicate) {
BlockFrame frame = FindBlockFrame("T.where");
if (frame->predicate.defined()) {
LOG(FATAL) << "Duplicate block predicate declaration, previous one is "
LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
<< frame->predicate.value();
}
frame->predicate = predicate;
Expand All @@ -102,7 +103,7 @@ void Reads(Array<ObjectRef> buffer_slices) {
using namespace tvm::tir;
BlockFrame frame = FindBlockFrame("T.reads");
if (!frame->reads.empty()) {
LOG(FATAL) << "Duplicate read region declaration, previous one is " << frame->reads;
LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
}
for (const ObjectRef& obj : buffer_slices) {
if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
Expand All @@ -119,7 +120,8 @@ void Writes(Array<ObjectRef> buffer_slices) {
using namespace tvm::tir;
BlockFrame frame = FindBlockFrame("T.writes");
if (!frame->writes.empty()) {
LOG(FATAL) << "Duplicate write region declaration, previous one is " << frame->writes;
LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
<< frame->writes;
}
for (const ObjectRef& obj : buffer_slices) {
if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
Expand All @@ -135,7 +137,7 @@ void Writes(Array<ObjectRef> buffer_slices) {
void BlockAttrs(Map<String, ObjectRef> attrs) {
BlockFrame frame = FindBlockFrame("T.block_attr");
if (!frame->annotations.empty()) {
LOG(FATAL) << "Duplicate block annotations, previous one is " << frame->annotations;
LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
}
frame->annotations = attrs;
}
Expand Down
7 changes: 4 additions & 3 deletions src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
void FuncName(String name) {
PrimFuncFrame frame = FindPrimFuncFrame("T.func_name");
if (frame->name.defined()) {
LOG(FATAL) << "Duplicate prim func name, previous one is " << frame->name.value();
LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value();
}
frame->name = name;
}
Expand All @@ -125,15 +125,16 @@ void FuncAttrs(Map<String, ObjectRef> attrs) {
using namespace tvm::tir;
PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr");
if (!frame->attrs.empty()) {
LOG(FATAL) << "Duplicate prim func annotations, previous one is " << frame->attrs;
LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs;
}
frame->attrs = attrs;
}

tvm::Type FuncRet(tvm::Type ret_type) {
PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type");
if (frame->ret_type.defined()) {
LOG(FATAL) << "Duplicate prim func return type, previous one is " << frame->ret_type.value();
LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
<< frame->ret_type.value();
}
frame->ret_type = ret_type;
return ret_type;
Expand Down
Loading

0 comments on commit f480eb2

Please sign in to comment.