diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
new file mode 100644
index 000000000000..aaa0e9c9174d
--- /dev/null
+++ b/python/tvm/ansor/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for the Ansor auto-scheduler"""
+
+from .compute_dag import ComputeDAG
diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py
new file mode 100644
index 000000000000..177299e67d21
--- /dev/null
+++ b/python/tvm/ansor/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.ansor"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("ansor", __name__)
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
new file mode 100644
index 000000000000..3c46440f75ba
--- /dev/null
+++ b/python/tvm/ansor/compute_dag.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""The Ansor computational graph (ComputeDAG) and related analyses"""
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from .state import State
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    def __init__(self, tensors):
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors)
+
+    def get_init_state(self) -> State:
+        return self.init_state
diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py
new file mode 100644
index 000000000000..9a8810190199
--- /dev/null
+++ b/python/tvm/ansor/state.py
@@ -0,0 +1,387 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""The scheduling state (loop structure) and its transform primitives"""
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    pass
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+
+    def iterator(self, index):
+        return _ffi_api.StageGetIterator(self, index)
+
+    def iterators(self):
+        return _ffi_api.StageGetIterators(self)
+
+
+@tvm._ffi.register_object("ansor.State")
+class State(Object):
+
+    def stage(self, index):
+        """
+        Parameters
+        ----------
+        index : Int
+
+        Returns
+        -------
+        stage : Stage
+        """
+        return _ffi_api.StateGetStage(self, index)
+
+    def transform_steps_size(self):
+        """ Return the number of transform steps
+        """
+        return _ffi_api.StateGetTransformStepsSize(self)
+
+    def reorder(self, stage_id, order):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        order : List[Iterator]
+            Iterators in the expected order
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateReorder(self, stage_id, order)
+        return state
+
+    def split(self, stage_id, it, lengths, inner_to_outer=True):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator
+        lengths: List[Int]
+            The split factors
+        inner_to_outer: Bool
+            True to split from inner to outer using `factor`,
+            False to split from outer to inner using `nparts`
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateSplit(self, stage_id, it, lengths,
+                                    inner_to_outer)
+        return state
+
+    def follow_split(self, stage_id, it, src_step_id, n_split):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator
+        src_step_id : Int
+            The index of the target step that this split follows
+        n_split : Int
+            Indicate how many levels need to be split out
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateFollowSplit(self, stage_id, it, src_step_id,
+                                          n_split)
+        return state
+
+    def follow_fused_split(self, stage_id, it, src_step_ids, level,
+                           factor_or_nparts):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator
+        src_step_ids : List[Int]
+            The indices of the target steps that this split follows
+        level : Int
+        factor_or_nparts : Bool
+            True to split from inner to outer using `factor`,
+            False to split from outer to inner using `nparts`
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateFollowFusedSplit(self, stage_id, it, src_step_ids,
+                                               level, factor_or_nparts)
+        return state
+
+    def fuse(self, stage_id, iters):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        iters : List[Iterator]
+            The target Iterators to be fused
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateFuse(self, stage_id, iters)
+        return state
+
+    def vectorize(self, stage_id, it):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator to be vectorized
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateVectorize(self, stage_id, it)
+        return state
+
+    def parallel(self, stage_id, it):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator to be parallelized
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateParallel(self, stage_id, it)
+        return state
+
+    def unroll(self, stage_id, it, max_unroll=-1):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator to be unrolled
+        max_unroll : Int
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateUnroll(self, stage_id, it, max_unroll)
+        return state
+
+    def bind_thread(self, stage_id, it, thread_type):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator to be bound
+        thread_type : Int
+            Supported types: kVThread, kBlockX, kThreadX, kThreadY
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateBindThread(self, stage_id, it, thread_type)
+        return state
+
+    def compute_at(self, stage_id, target_stage_id, target_iter):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        target_stage_id : Int
+            The index of the stage to compute at
+        target_iter : Iterator
+            The iterator of the target stage to compute at
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        return _ffi_api.StateComputeAt(self, stage_id, target_stage_id,
+                                       target_iter)
+
+    def compute_root(self, stage_id):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        return _ffi_api.StateComputeRoot(self, stage_id)
+
+    def compute_inline(self, stage_id):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        return _ffi_api.StateComputeInline(self, stage_id)
+
+    def pack_for_vec(self, stage_id, target_iter, vec_size):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        target_iter : Iterator
+            The target Iterator
+        vec_size : Int
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        return _ffi_api.StatePackForVec(self, stage_id, target_iter, vec_size)
+
+    def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        scope_name : Str
+        reader_stage_ids : List[Int]
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateCacheRead(self, stage_id, scope_name,
+                                        reader_stage_ids, task_dag)
+        return state
+
+    def cache_write(self, stage_id, scope_name, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        scope_name : Str
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateCacheWrite(self, stage_id, scope_name, task_dag)
+        return state
+
+    def pragma(self, stage_id, it, pragma_type):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+            The target Iterator
+        pragma_type : Str
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        return _ffi_api.StatePragma(self, stage_id, it, pragma_type)
+
+    def rfactor(self, stage_id, it, factor_iter_id, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+        factor_iter_id : Int
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        state = _ffi_api.StateRfactor(self, stage_id, it, factor_iter_id,
+                                      task_dag)
+        return state
+
+    def storage_align(self, stage_id, it, factor, offset):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the target stage
+        it : Iterator
+        factor : Int
+        offset : Int
+
+        Returns
+        -------
+        state : State
+            The updated state
+        """
+        return _ffi_api.StateStorageAlign(self, stage_id, it, factor, offset)
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index feaefe9f8e9f..1e33068e4965 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -1166,6 +1166,8 @@ std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
   return std::make_pair(schedule, operator->()->tensors);
 }
 
+TVM_REGISTER_GLOBAL("ansor.ComputeDAG")
+.set_body_typed([](Array<te::Tensor> tensors) { return ComputeDAGNode::make(tensors); });
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter *p) {
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc
index f01899c4c793..ebea5a1e472a 100644
--- a/src/ansor/loop_state.cc
+++ b/src/ansor/loop_state.cc
@@ -3,11 +3,13 @@
  */
 #include "loop_state.h"
 #include <utility>
+#include <tvm/runtime/registry.h>
 #include "utils.h"
 
 namespace tvm {
 namespace ansor {
 
+TVM_REGISTER_OBJECT_TYPE(StageNode);
 TVM_REGISTER_NODE_TYPE(StateNode);
 
 Stage StageNode::make(te::Operation op) {
@@ -65,6 +67,16 @@ Stage StageNode::make(te::Operation op, StageType op_type,
   return Stage(node);
 }
 
+TVM_REGISTER_GLOBAL("ansor.StageGetIterator")
+  .set_body_typed([](const Stage& stage, int index) {
+    return stage->iters[index];
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StageGetIterators")
+  .set_body_typed([](const Stage& stage) {
+    return Array<Iterator>(stage->iters);
+  });
+
 State StateNode::make_empty_state() {
   auto node = make_object<StateNode>();
   node->attach_map = AttachMapNode::make();
@@ -873,6 +885,143 @@ std::string State::ToStr(bool delete_trivial_loop) const {
   return os.str();
 }
 
+TVM_REGISTER_GLOBAL("ansor.StateGetStage")
+  .set_body_typed([](const State& state, int index) {
+    return state->stages[index];
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize")
+  .set_body_typed([](const State& state) {
+    return static_cast<int>(state->transform_steps.size());
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateReorder")
+  .set_body_typed([](State state, int stage_id,
+                     const Array<Iterator>& order) {
+    std::vector<Iterator> ord;
+    for (const auto& i : order) {
+      ord.push_back(i);
+    }
+    state.reorder(stage_id, ord);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateSplit")
+  .set_body_typed([](State state, int stage_id, const Iterator& it,
+                     const Array<PrimExpr>& lengths,
+                     bool inner_to_outer) {
+    std::vector<PrimExpr> len;
+    for (const auto& i : lengths) {
+      len.push_back(i);
+    }
+    state.split(stage_id, it, len, inner_to_outer);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateFollowSplit")
+  .set_body_typed([](State state, int stage_id, const Iterator& it,
+                     int src_step_id, int n_split) {
+    state.follow_split(stage_id, it, src_step_id, n_split);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit")
+  .set_body_typed([](State state, int stage_id, const Iterator& it,
+                     const Array<Integer>& src_step_ids, int level,
+                     bool factor_or_nparts) {
+    std::vector<int> array_src_step_ids;
+    for (const auto& i : src_step_ids) {
+      array_src_step_ids.push_back(i->value);
+    }
+    state.follow_fused_split(stage_id, it, array_src_step_ids, level,
+                             factor_or_nparts);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateFuse")
+  .set_body_typed([](State state, int stage_id,
+                     const Array<Iterator>& iters) {
+    std::vector<Iterator> its;
+    for (const auto& i : iters) {
+      its.push_back(i);
+    }
+    state.fuse(stage_id, its);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateVectorize")
+  .set_body_typed([](State state, int stage_id,
+                     const Iterator& it) {
+    state.vectorize(stage_id, it);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateParallel")
+  .set_body_typed([](State state, int stage_id,
+                     const Iterator& it) {
+    state.parallel(stage_id, it);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateUnroll")
+  .set_body_typed([](State state, int stage_id,
+                     const Iterator& it, int max_unroll) {
+    state.unroll(stage_id, it, max_unroll);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateBindThread")
+  .set_body_typed([](State state, int stage_id,
+                     const Iterator& it, int thread_type) {
+    state.bind_thread(stage_id, it, IteratorAnnotation(thread_type));
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateComputeAt")
+  .set_body_typed([](State state, int stage_id, int target_stage_id,
+                     const Iterator& target_iter) {
+    state.compute_at(stage_id, target_stage_id, target_iter);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateComputeRoot")
+  .set_body_typed([](State state, int stage_id) {
+    state.compute_root(stage_id);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateComputeInline")
+  .set_body_typed([](State state, int stage_id) {
+    state.compute_inline(stage_id);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StatePackForVec")
+  .set_body_typed([](State state, int stage_id,
+                     const Iterator& target_iter, int vec_size) {
+    state.pack_for_vec(stage_id, target_iter, vec_size);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateCacheRead")
+  .set_body_typed([](State state, int stage_id, const std::string& scope_name,
+                     const Array<Integer>& reader_stage_ids,
+                     const ComputeDAG& task_dag) {
+    std::vector<int> array_reader_stage_ids;
+    for (const auto& i : reader_stage_ids) {
+      array_reader_stage_ids.push_back(i->value);
+    }
+    state.cache_read(stage_id, scope_name, array_reader_stage_ids, task_dag);
+    return state;
+  });
+
+TVM_REGISTER_GLOBAL("ansor.StateCacheWrite")
+  .set_body_typed([](State state, int stage_id, const std::string& scope_name,
+                     const ComputeDAG& task_dag) {
+    state.cache_write(stage_id, scope_name, task_dag);
+    return state;
+  });
+
 void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id,
                                  int target_iter_id) {
   AttachMapNode* pnode = CopyOnWrite();
diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc
index 8cd8233ae9be..5f4a6a8dcef9 100644
--- a/src/ansor/transform_step.cc
+++ b/src/ansor/transform_step.cc
@@ -8,6 +8,7 @@
 namespace tvm {
 namespace ansor {
 
+TVM_REGISTER_NODE_TYPE(IteratorNode);
 TVM_REGISTER_OBJECT_TYPE(StepNode);
 
 /********** Reorder **********/
diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h
index 9b430be99bd3..627ce02b60e1 100644
--- a/src/ansor/transform_step.h
+++ b/src/ansor/transform_step.h
@@ -69,6 +69,11 @@ class IteratorNode : public Object {
                        IteratorType iter_type, IteratorAnnotation annotation,
                        const std::vector<Iterator>* ori_iters = nullptr);
 
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("range", &range);
+  }
+
   static constexpr const char *_type_key = "ansor.Iterator";
   TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
 };
diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc
index bbcef05f31fc..75a6cc00b802 100644
--- a/tests/cpp/ansor_test.cc
+++ b/tests/cpp/ansor_test.cc
@@ -19,13 +19,15 @@
 
 #include <gtest/gtest.h>
 #include <tvm/te/operation.h>
-#include <topi/nn.h>
-#include <unordered_set>
 #include <vector>
-#include "../../src/ansor/loop_state.h"
-#include "../../src/ansor/serialization.h"
+#include <topi/nn.h>
+
+#include <unordered_set>
+
 #include "../../src/ansor/feature.h"
+#include "../../src/ansor/loop_state.h"
 #include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h"
+#include "../../src/ansor/serialization.h"
 
 tvm::Array<te::Tensor> matmul_func(int n, int m, int k) {
   using namespace tvm;
@@ -35,16 +37,17 @@
   Tensor B = placeholder({k, m}, DataType::Float(32), "B");
   IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce);
   const auto& C = compute(
-      {n, m},
-      [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); },
+      {n, m}, [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); },
       "C");
 
   return {A, B, C};
 }
 
 tvm::Array<te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W,
-                                                int CI, int CO, int kernel_size, int strides, int padding,
-                                                int dilation = 1) {
+                                                int CI, int CO,
+                                                int kernel_size,
+                                                int strides, int padding,
+                                                int dilation = 1) {
   using namespace tvm;
   using namespace tvm::te;
@@ -58,27 +61,27 @@ tvm::Array<te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W,
   int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
   int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
 
-  const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides,
-                                       strides);
+  const auto& conv =
+      topi::conv2d_nchw(data, kernel, padding, padding, strides, strides);
   CHECK(conv->shape[2].as<IntImmNode>()->value == OH);
   CHECK(conv->shape[3].as<IntImmNode>()->value == OW);
 
   const auto& bias_add = compute(
       {N, CO, OH, OW},
       [&](Var i, Var j, Var k, Var l) {
-          return conv[i][j][k][l] + bias[j][0][0];
+        return conv[i][j][k][l] + bias[j][0][0];
       },
       "Bias_add");
   const auto& bn_mul = compute(
       {N, CO, OH, OW},
       [&](Var i, Var j, Var k, Var l) {
-          return bias_add[i][j][k][l] * bn_scale[j][0][0];
+        return bias_add[i][j][k][l] * bn_scale[j][0][0];
      },
      "Bn_mul");
   const auto& bn_add = compute(
      {N, CO, OH, OW},
      [&](Var i, Var j, Var k, Var l) {
-          return bn_mul[i][j][k][l] + bn_offset[j][0][0];
+        return bn_mul[i][j][k][l] + bn_offset[j][0][0];
      },
      "Bn_add");
   const auto& out = topi::relu(bn_add);
@@ -109,20 +112,22 @@ TEST(ComputeDAG, GetProducersConsumers) {
   std::unordered_set<te::Operation, ObjectHash, ObjectEqual> set;
   {
     std::vector<std::pair<int, int>> consumer_list = {
-        {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add},
-        {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul},
-        {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu}
-    };
+        {data, padding},    {padding, conv},    {kernel, conv},
+        {conv, bias_add},   {bias, bias_add},   {bias_add, bn_mul},
+        {bn_scale, bn_mul}, {bn_mul, bn_add},   {bn_offset, bn_add},
+        {bn_add, relu}};
     for (const auto& pair : consumer_list) {
       dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set);
       CHECK_EQ(set.size(), 1);
       CHECK_EQ((*set.begin()), s0->stages[pair.second]->op);
     }
     std::vector<std::pair<int, std::vector<int>>> producer_list = {
-        {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}},
-        {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}},
-        {relu, {bn_add}}
-    };
+        {padding, {data}},
+        {conv, {padding, kernel}},
+        {bias_add, {conv, bias}},
+        {bn_mul, {bias_add, bn_scale}},
+        {bn_add, {bn_mul, bn_offset}},
+        {relu, {bn_add}}};
     for (const auto& pair : producer_list) {
       dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set);
       CHECK_EQ(set.size(), pair.second.size());
@@ -138,18 +143,19 @@ TEST(ComputeDAG, GetProducersConsumers) {
   s0.compute_inline(padding);
   {
     std::vector<std::pair<int, int>> consumer_list = {
-        {data, conv}, {kernel, conv}, {conv, relu}
-    };
+        {data, conv}, {kernel, conv}, {conv, relu}};
    for (const auto& pair : consumer_list) {
      dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set);
      CHECK_EQ(set.size(), 1);
      CHECK_EQ((*set.begin()), s0->stages[pair.second]->op);
    }
    std::vector<std::pair<int, std::vector<int>>> producer_list = {
-        {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}},
-        {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}},
-        {relu, {bn_add}}
-    };
+        {padding, {data}},
+        {conv, {padding, kernel}},
+        {bias_add, {conv, bias}},
+        {bn_mul, {bias_add, bn_scale}},
+        {bn_add, {bn_mul, bn_offset}},
+        {relu, {bn_add}}};
    for (const auto& pair : producer_list) {
      dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set);
      CHECK_EQ(set.size(), pair.second.size());
@@ -170,15 +176,19 @@ TEST(ComputeDAG, InferBoundSerialization) {
   C++;
   const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 8, 8});
   const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4});
-  s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2],
-                 its0[3], its1[3]});
+  s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3],
+                 its1[3]});
   s0.compute_at(C_global, C, s0->stages[C]->iters[3]);
   s0.split(C_global, s0->stages[C_global]->iters[2], {16});
   int B_global = s0.cache_read(B, "global", {C_global}, dag);
-  C++; C_global++;
+  C++;
+  C_global++;
   s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]);
   int A_global = s0.cache_read(A, "global", {C_global}, dag);
-  B++; B_global++; C++; C_global++;
+  B++;
+  B_global++;
+  C++;
+  C_global++;
   s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]);
 
   const auto& s1 = dag.InferBound(s0);
@@ -186,23 +196,26 @@
   dag.InferBound(&s2);
   const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps);
 
-  CHECK_EQ(s1->stages[B_global]->iters[0]->range->extent.as<IntImmNode>()->value,
-           512);
-  CHECK_EQ(s1->stages[B_global]->iters[1]->range->extent.as<IntImmNode>()->value,
-           16);
-  CHECK_EQ(s1->stages[A_global]->iters[0]->range->extent.as<IntImmNode>()->value,
-           1);
-  CHECK_EQ(s1->stages[A_global]->iters[1]->range->extent.as<IntImmNode>()->value,
-           16);
-  CHECK_EQ(s1->stages[C_global]->iters[0]->range->extent.as<IntImmNode>()->value,
-           64);
+  CHECK_EQ(
+      s1->stages[B_global]->iters[0]->range->extent.as<IntImmNode>()->value,
+      512);
+  CHECK_EQ(
+      s1->stages[B_global]->iters[1]->range->extent.as<IntImmNode>()->value,
+      16);
+  CHECK_EQ(
+      s1->stages[A_global]->iters[0]->range->extent.as<IntImmNode>()->value, 1);
+  CHECK_EQ(
+      s1->stages[A_global]->iters[1]->range->extent.as<IntImmNode>()->value,
+      16);
+  CHECK_EQ(
+      s1->stages[C_global]->iters[0]->range->extent.as<IntImmNode>()->value,
+      64);
   CHECK(std::equal_to<State>()(s1, s2[0]));
   CHECK(std::equal_to<State>()(s1, s3));
 
   const auto& minp0 = MeasureInputNode::make(
       SearchTaskNode::make(dag, "test", tvm::target::llvm(),
-                           tvm::target::llvm(),
-                           HardwareParams()),
+                           tvm::target::llvm(), HardwareParams()),
       s0);
   const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1);
   std::stringstream ss;
@@ -242,7 +255,8 @@ TEST(Step, SplitFuseReorder) {
   CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as<IntImmNode>()->value, 512);
 
   s0.fuse(2, {tio, tjo});
-  CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as<IntImmNode>()->value, 2048);
+  CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as<IntImmNode>()->value,
+           2048);
 
   s1.split(2, ti, {8, 2});
   s1.split(2, tj, {32, 8}, false);
@@ -271,10 +285,12 @@ TEST(Step, ComputeAtRootInline) {
   s0.compute_inline(bn_mul);
   s0.compute_inline(bias_add);
   s0.compute_at(conv, relu, s0->stages[relu]->iters[2]);
-  const auto& conv_stage_attach = s0->attach_map->stage_to_attach_iter.find(conv);
+  const auto& conv_stage_attach =
+      s0->attach_map->stage_to_attach_iter.find(conv);
   std::pair<int, int> iterkey(relu, 2);
   CHECK(conv_stage_attach->second == iterkey);
-  const auto& conv_iter_attach = s0->attach_map->iter_to_attached_stages.find(iterkey);
+  const auto& conv_iter_attach =
+      s0->attach_map->iter_to_attached_stages.find(iterkey);
   CHECK_EQ(conv_iter_attach->second.size(), 1);
   CHECK_EQ(conv_iter_attach->second[0], conv);
   std::stringstream ss;
@@ -335,25 +351,28 @@ TEST(Step, CacheReadWrite) {
   int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1;
   int padding = 1;
   Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data");
-  Tensor kernel_data = placeholder({CO, CI, KH, KW}, DataType::Float(32),
-                                   "kernel_data");
-  const auto& k_split = compute(kernel_data->shape,
+  Tensor kernel_data =
+      placeholder({CO, CI, KH, KW}, DataType::Float(32), "Kernel_data");
+  const auto& k_split = compute(
+      kernel_data->shape,
       [&](const Array<Var>& i) {
-          return Array<PrimExpr>({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1,
-                                  div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)});
+        return Array<PrimExpr>({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1,
+                                div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)});
       },
       "Kernel_split");
-  const auto& kernel = compute(kernel_data->shape,
+  const auto& kernel = compute(
+      kernel_data->shape,
       [&](Var i, Var j, Var k, Var l) {
-          return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l];
+        return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l];
       },
       "Kernel");
-  const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, stride,
-                                       stride);
+  const auto& conv =
+      topi::conv2d_nchw(data, kernel, padding, padding, stride, stride);
   const auto& relu = topi::relu(conv);
-  const auto& out = compute(relu->shape,
+  const auto& out = compute(
+      relu->shape,
       [&](Var i, Var j, Var k, Var l) {
-          return data[i][j][k][l] + relu[i][j][k][l];
+        return data[i][j][k][l] + relu[i][j][k][l];
       },
       "Add");
   return {data, kernel_data, out};
@@ -372,15 +391,20 @@
 
   // 1: simple cache_write with compute_at
   int conv_global = s0.cache_write(conv, "global", dag);
-  conv++; relu++; add++;
+  conv++;
+  relu++;
+  add++;
   s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]);
 
   // 2: simple cache_read with compute_at
   int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag);
-  conv_global++; conv++; relu++; add++;
+  conv_global++;
+  conv++;
+  relu++;
+  add++;
   s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]);
 
   std::stringstream ss;
-  ss << "Placeholder: Data, kernel_data\n"
+  ss << "Placeholder: Data, Kernel_data\n"
      << "for ax0 (0,4)\n"
     << "  for ax1 (0,512)\n"
    << "    for ax2 (0,9)\n"
@@ -425,25 +449,45 @@ TEST(Step, CacheReadWrite) {
   // 3: two level cache_read with compute_at
   // preparing for GPU's shared memory & local memory
   int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag);
-  kernel_data++; kernel_split++; kernel++; kernel_global++;
-  conv_global++; conv++; relu++; add++;
-  int pad_temp_shared = s0.cache_read(pad_temp_global, "shared", {conv_global},
-                                      dag);
-  kernel_data++; kernel_split++; kernel++; kernel_global++;
-  conv_global++; conv++; relu++; add++;
+  kernel_data++;
+  kernel_split++;
+  kernel++;
+  kernel_global++;
+  conv_global++;
+  conv++;
+  relu++;
+  add++;
+  int pad_temp_shared =
+      s0.cache_read(pad_temp_global, "shared", {conv_global}, dag);
+  kernel_data++;
+  kernel_split++;
+  kernel++;
+  kernel_global++;
+  conv_global++;
+  conv++;
+  relu++;
+  add++;
   s0.compute_at(pad_temp_global, conv_global,
                 s0->stages[conv_global]->iters[2]);
   s0.compute_at(pad_temp_shared, conv_global,
                 s0->stages[conv_global]->iters[4]);
 
   // 4: cache_read with multi readers
-  // This stage cannot be compute at to its consumer
+  // This stage cannot be computed at its consumer
   s0.cache_read(data, "global", {pad_temp, add}, dag);
-  pad_temp++; pad_temp_global++; pad_temp_shared++;
-  kernel_data++; kernel_split++; kernel++; kernel_global++;
-  conv_global++; conv++; relu++; add++;
+  pad_temp++;
+  pad_temp_global++;
+  pad_temp_shared++;
+  kernel_data++;
+  kernel_split++;
+  kernel++;
+  kernel_global++;
+  conv_global++;
+  conv++;
+  relu++;
+  add++;
 
   ss.str(std::string());
-  ss << "Placeholder: Data, kernel_data\n"
+  ss << "Placeholder: Data, Kernel_data\n"
      << "for ax0 (0,4)\n"
     << "  for ax1 (0,512)\n"
    << "    for ax2 (0,7)\n"
@@ -517,7 +561,7 @@ TEST(Step, CacheReadWrite) {
   // To be fixed in the future
   s0.cache_write(kernel_split, "global", dag);
   ss.str(std::string());
-  ss << "Placeholder: Data, kernel_data\n"
+  ss << "Placeholder: Data, Kernel_data\n"
      << "for ax0 (0,4)\n"
     << "  for ax1 (0,512)\n"
    << "    for ax2 (0,7)\n"
@@ -598,8 +642,8 @@ TEST(Step, FollowSplitFollowFusedSplit) {
   // FollowSplitStep currently only support `inner_to_outer = true`
   const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true);
   int split_step0 = s0->transform_steps.size() - 1;
-  // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, false);
-  // int split_step1 = s0->transform_steps.size() - 1;
+  // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4},
+  // false); int split_step1 = s0->transform_steps.size() - 1;
   for (int level = 1; level <= 5; level++) {
     State tmp = s0;
     tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0,
@@ -610,7 +654,7 @@
     const auto& stage_C = tmp->stages[C];
     const auto& stage_C_global = tmp->stages[C_global];
     for (int i = 0; i < level; i++) {
       CHECK_EQ(stage_C->iters[i]->range->extent.as<IntImmNode>()->value,
-          stage_C_global->iters[i]->range->extent.as<IntImmNode>()->value);
+               stage_C_global->iters[i]->range->extent.as<IntImmNode>()->value);
     }
     // for (int i = 0; i < level; i++) {
     //   CHECK(stage_C->iters[i+5]->range->extent.as<IntImmNode>()->value ==
@@ -627,7 +671,7 @@
   }
   s0.reorder(C, its);
   for (int i = 0; i < 5; i++) {
-    s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i+1]});
+    s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i + 1]});
   }
   for (int level = 0; level < 4; level++) {
     State tmp = s0;
     tmp.follow_fused_split(C_global, s0->stages[C_global]->iters[0],
@@ -635,8 +679,8 @@
                            {split_step0, split_step1}, level, false);
     const auto& stage_C = tmp->stages[C];
     const auto& stage_C_global = tmp->stages[C_global];
-    CHECK_EQ(stage_C->iters[level+1]->range->extent.as<IntImmNode>()->value,
-        stage_C_global->iters[0]->range->extent.as<IntImmNode>()->value);
+    CHECK_EQ(stage_C->iters[level + 1]->range->extent.as<IntImmNode>()->value,
+             stage_C_global->iters[0]->range->extent.as<IntImmNode>()->value);
   }
   for (int level = 0; level < 4; level++) {
     State tmp = s0;
     tmp.follow_fused_split(C_global, s0->stages[C_global]->iters[0],
@@ -644,8 +688,8 @@
                            {split_step0, split_step1}, level, true);
     const auto& stage_C = tmp->stages[C];
     const auto& stage_C_global = tmp->stages[C_global];
-    CHECK_EQ(stage_C->iters[level+1]->range->extent.as<IntImmNode>()->value,
-        stage_C_global->iters[1]->range->extent.as<IntImmNode>()->value);
+    CHECK_EQ(stage_C->iters[level + 1]->range->extent.as<IntImmNode>()->value,
+             stage_C_global->iters[1]->range->extent.as<IntImmNode>()->value);
   }
 }
@@ -676,10 +720,10 @@ TEST(Feature, ExtractionMatmul) {
   std::vector<std::vector<float>> features;
   std::vector<std::string> feature_names;
   GetPerStmtFeatureName(max_n_bufs, &feature_names);
-  GetPerStmtFeaturesFromStates({s0},
+  GetPerStmtFeaturesFromStates(
+      {s0},
       SearchTaskNode::make(dag, "test", tvm::target::llvm(),
-                           tvm::target::llvm(),
-                           HardwareParams()),
+                           tvm::target::llvm(), HardwareParams()),
       max_n_bufs, 0, &features);
   int num_states = 1;
   CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states);
@@ -704,7 +748,7 @@ class MetaTileRewritePolicyNodeTest {
     policy->SynthesizeMetaStructure(meta_structures);
   }
   void SampleInitPopulation(const std::vector<State>& meta_structures,
-      int out_size, std::vector<State>* out_states) {
+                            int out_size, std::vector<State>* out_states) {
     policy->SampleInitPopulation(meta_structures, out_size, out_states);
   }
   tvm::runtime::ObjectPtr<MetaTileRewritePolicyNode> policy;
diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py
new file mode 100644
index 000000000000..4782f9130cea
--- /dev/null
+++ b/tests/python/unittest/test_ansor_common.py
@@ -0,0 +1,475 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+from tvm import ansor
+import topi
+
+
+def matmul_nkkm(N, M, K):
+    A = te.placeholder((N, K), name='A')
+    B = te.placeholder((K, M), name='B')
+    k = te.reduce_axis((0, K), name='k')
+    C = te.compute((N, M), lambda i, j: te.sum(
+        A[i][k] * B[k][j], axis=[k]), name='C')
+
+    return [A, B, C]
+
+
+def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
+    data = te.placeholder((N, CI, H, W), name='Data')
+    kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
+    bias = te.placeholder((CO, 1, 1), name='Bias')
+    bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale')
+    bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset')
+
+    OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+    OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+
+    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation)
+    conv = te.compute((N, CO, OH, OW),
+                      lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0],
+                      name='Bias_add')
+    conv = te.compute((N, CO, OH, OW),
+                      lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0],
+                      name='Bn_mul')
+    conv = te.compute((N, CO, OH, OW),
+                      lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0],
+                      name='Bn_add')
+    out = topi.nn.relu(conv)
+
+    return [data, kernel, bias, bn_offset, bn_scale, out]
+
+
+def test_compute_dag_basic():
+    dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3))
+
+    print(dag)
+    print(dag.access_analyzer)
+    print(dag.get_init_state())
+
+
+def test_state_split_fuse_reorder():
+    dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512))
+    s0 = dag.get_init_state()
+    s1 = s0
+    ti = s0.stage(2).iterator(0)
+    tj = s0.stage(2).iterator(1)
+    tk = s0.stage(2).iterator(2)
+
+    assert ti.range.extent == 512
+
+    s0 = s0.split(2, ti, [16])
+    assert s0.stage(2).iterator(0).range.extent == 32
+    assert s0.stage(2).iterator(1).range.extent == 16
+    tio = s0.stage(2).iterator(0)
+    tii = s0.stage(2).iterator(1)
+
+    s0 = s0.split(2, tj, [8])
+    assert s0.stage(2).iterator(2).range.extent == 64
+    assert s0.stage(2).iterator(3).range.extent == 8
+    tjo = s0.stage(2).iterator(2)
+    tji = s0.stage(2).iterator(3)
+
+    s0 = s0.reorder(2, [tio, tjo, tk, tji, tii])
+    assert s0.stage(2).iterator(2).range.extent == 512
+
+    s0 = s0.fuse(2, [tio, tjo])
+    assert s0.stage(2).iterator(0).range.extent == 2048
+
+    s1 = s1.split(2, ti, [8, 2])
+    s1 = s1.split(2, tj, [32, 8], False)
+    assert s1.stage(2).iterator(0).range.extent == 32
+    assert s1.stage(2).iterator(1).range.extent == 8
+    assert s1.stage(2).iterator(2).range.extent == 2
+    assert s1.stage(2).iterator(3).range.extent == 32
+    assert s1.stage(2).iterator(4).range.extent == 8
+    assert s1.stage(2).iterator(5).range.extent == 2
+
+
+def test_state_compute_at_root_inline():
+    dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3))
+
+    # data, padding, kernel = 0, 1, 2
+    conv = 3
+    # bias = 4
+    bias_add = 5
+    # bn_scale = 6
+    bn_mul = 7
+    # bn_offset = 8
+    bn_add, relu = 9, 10
+
+    s0 = dag.get_init_state()
+    s0 = s0.compute_inline(bn_add)
+    s0 = s0.compute_inline(bn_mul)
+    s0 = s0.compute_inline(bias_add)
+    s0 = s0.compute_at(conv, relu, s0.stage(relu).iterator(2))
+    assert str(s0) == \
+        "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \
+        "for i1 (0,3)\n" + \
+        "  for i2 (0,230)\n" + \
+        "    for i3 (0,230)\n" + \
+        "      pad_temp = ...\n" + \
+        "for i1 (0,64)\n" + \
+        "  for i2 (0,112)\n" + \
+        "    for nn (None)\n" + \
+        "      for ff (None)\n" + \
+        "        for yy (None)\n" + \
+        "          for xx (None)\n" + \
+        "            for rc (None)\n" + \
+        "              for ry (None)\n" + \
+        "                for rx (None)\n" + \
+        "                  compute = ...\n" + \
+        "    for i3 (0,112)\n" + \
+        "      compute = ...\n"
+
+    s0 = s0.compute_root(conv)
+    s0 = s0.compute_root(bn_mul)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \
+        "for i1 (0,3)\n" + \
+        "  for i2 (0,230)\n" + \
+        "    for i3 (0,230)\n" + \
+        "      pad_temp = ...\n" + \
+        "for nn (None)\n" + \
+        "  for ff (None)\n" + \
+        "    for yy (None)\n" + \
+        "      for xx (None)\n" + \
+        "        for rc (None)\n" + \
+        "          for ry (None)\n" + \
+        "            for rx (None)\n" + \
+        "              compute = ...\n" + \
+        "for i (None)\n" + \
+        "  for j (None)\n" + \
+        "    for k (None)\n" + \
+        "      for l (None)\n" + \
+        "        Bn_mul = ...\n" + \
+        "for i1 (0,64)\n" + \
+        "  for i2 (0,112)\n" + \
+        "    for i3 (0,112)\n" + \
+        "      compute = ...\n"
+
+
+def test_state_cache_read_write():
+    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
+        1, 1), (1, 1)
+
+    data = te.placeholder((N, CI, H, W), name='Data')
+    kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data')
+    k0, k1 = te.compute(kernel_data.shape,
+                        lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2),
+                        name='Kernel_split')
+    kernel = te.compute(kernel_data.shape,
+                        lambda *i: k0(*i) + k1(*i),
+                        name='Kernel')
+    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
+    relu = topi.nn.relu(conv)
+    out = topi.add(data, relu)
+
+    dag = ansor.ComputeDAG([data, kernel_data, out])
+    data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7
+
+    # 0: init state
+    s0 = dag.get_init_state()
+    ori_its = s0.stage(add).iterators()
+    s0 = s0.split(add, s0.stage(add).iterator(0), [2])
+    s0 = s0.reorder(add, [s0.stage(add).iterator(0), ori_its[1],
+                          s0.stage(add).iterator(1), ori_its[2], ori_its[3]])
+    s0 = s0.compute_inline(relu)
+
+    # 1: simple cache_write with compute_at
+    s0 = s0.cache_write(conv, "global", dag)
+    conv_global = conv
+    conv += 1
+    relu += 1
+    add += 1
+    s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3))
+
+    # 2: simple cache_read with compute_at
+    s0 = s0.cache_read(kernel, "global", [conv_global], dag)
+    kernel_global = kernel + 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    s0 = s0.compute_at(kernel_global, conv_global,
+                       s0.stage(conv_global).iterator(4))
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+    # 3: two level cache_read with compute_at
+    # preparing for GPU's shared memory & local memory
+    s0 = s0.cache_read(pad_temp, "global", [conv_global], dag)
+    pad_temp_global = pad_temp + 1
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    s0 = s0.cache_read(pad_temp_global, "shared", [conv_global], dag)
+    pad_temp_shared = pad_temp_global + 1
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    s0 = s0.compute_at(pad_temp_global, conv_global,
+                       s0.stage(conv_global).iterator(2))
+    s0 = s0.compute_at(pad_temp_shared, conv_global,
+                       s0.stage(conv_global).iterator(4))
+
+    # 4: cache_read with multi readers
+    # This stage cannot be computed at its consumer
+    s0 = s0.cache_read(data, "global", [pad_temp, add], dag)
+    pad_temp += 1
+    pad_temp_global += 1
+    pad_temp_shared += 1
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for ax0 (0,4)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax2 (0,7)\n" + \
+        "      for ax3 (0,7)\n" + \
+        "        Data.global = ...\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for ax0 (None)\n" + \
+        "                for ax1 (None)\n" + \
+        "                  for ax2 (None)\n" + \
+        "                    for ax3 (None)\n" + \
+        "                      pad_temp.global = ...\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          pad_temp.global.shared = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+    # 5: cache_write with multi outputs
+    # See tests/cpp/ansor_test.cc for more information
+    s0 = s0.cache_write(kernel_split, "global", dag)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for ax0 (0,4)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax2 (0,7)\n" + \
+        "      for ax3 (0,7)\n" + \
+        "        Data.global = ...\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0_c (0,512)\n" + \
+        "  for i1_c (0,512)\n" + \
+        "    for i2_c (0,3)\n" + \
+        "      for i3_c (0,3)\n" + \
+        "        Kernel_split.global = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for ax0 (None)\n" + \
+        "                for ax1 (None)\n" + \
+        "                  for ax2 (None)\n" + \
+        "                    for ax3 (None)\n" + \
+        "                      pad_temp.global = ...\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          pad_temp.global.shared = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+
+def test_follow_split_follow_fused_split():
+    dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512))
+    s0 = dag.get_init_state()
+    C = 2
+
+    s0 = s0.cache_write(C, "global", dag)
+    C_global = C
+    C += 1
+
+    s0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True)
+    split_step0 = s0.transform_steps_size() - 1
+    for level in range(1, 6):
+        tmp = s0
+        tmp = tmp.follow_split(C_global, tmp.stage(
+            C_global).iterator(0), split_step0, level)
+        for i in range(0, level):
+            assert tmp.stage(C).iterator(i).range.extent == \
+                tmp.stage(C_global).iterator(i).range.extent
+
+    s0 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8])
+    split_step1 = s0.transform_steps_size() - 1
+    its = s0.stage(C).iterators()
+    s0 = s0.reorder(C, [its[0], its[5], its[1], its[6], its[2], its[7],
+                        its[3], its[8], its[4], its[9]])
+    s0 = s0.fuse(C, [s0.stage(C).iterator(0), s0.stage(C).iterator(1)])
+    s0 = s0.fuse(C, [s0.stage(C).iterator(1), s0.stage(C).iterator(2)])
+    s0 = s0.fuse(C, [s0.stage(C).iterator(2), s0.stage(C).iterator(3)])
+    s0 = s0.fuse(C, [s0.stage(C).iterator(3), s0.stage(C).iterator(4)])
+    s0 = s0.fuse(C, [s0.stage(C).iterator(4), s0.stage(C).iterator(5)])
+    for level in range(0, 4):
+        tmp = s0
+        tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0),
+                                     [split_step0, split_step1], level, False)
+        assert tmp.stage(C).iterator(level+1).range.extent == \
+            tmp.stage(C_global).iterator(0).range.extent
+    for level in range(0, 4):
+        tmp = s0
+        tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0),
+                                     [split_step0, split_step1], level, True)
+        assert tmp.stage(C).iterator(level+1).range.extent == \
+            tmp.stage(C_global).iterator(1).range.extent
+
+
+def test_rfactor():
+    pass
+
+
+if __name__ == "__main__":
+    test_compute_dag_basic()
+    test_state_split_fuse_reorder()
+    test_state_compute_at_root_inline()
+    test_state_cache_read_write()
+    test_follow_split_follow_fused_split()
+    test_rfactor()
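
For reference, a minimal end-to-end sketch of the Python API this patch adds, distilled from the unit tests above. It assumes TVM is built with this patch applied so that `tvm.ansor` is importable; the workload mirrors `matmul_nkkm` from the tests, and note that every `State` method returns an updated `State` rather than mutating in place:

    import tvm
    from tvm import te, ansor

    # Declare a 512x512 matmul with the te API (same shapes as the tests).
    A = te.placeholder((512, 512), name='A')
    B = te.placeholder((512, 512), name='B')
    k = te.reduce_axis((0, 512), name='k')
    C = te.compute((512, 512),
                   lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
                   name='C')

    # Wrap the tensors in a ComputeDAG and fetch the initial loop state.
    dag = ansor.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Stage 2 is the compute stage C (stages 0 and 1 are the placeholders).
    ti = s.stage(2).iterator(0)
    s = s.split(2, ti, [16])   # i -> (i outer of extent 32, i inner of extent 16)
    s = s.fuse(2, [s.stage(2).iterator(0), s.stage(2).iterator(1)])
    print(s)                   # prints the loop structure, as asserted in the tests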