From 543e446886841d5564c58c434b61366864fa49c3 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 4 Oct 2021 06:54:20 +0800 Subject: [PATCH] [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (#483) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments --- include/tvm/tir/builtin.h | 10 +++ python/tvm/script/tir/intrin.py | 10 +++ python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 56 +++++++++++++++ src/target/source/codegen_cuda.cc | 29 ++++++++ src/target/source/codegen_cuda.h | 2 + .../source/literal/cuda_binary_search.h | 69 +++++++++++++++++++ src/tir/op/builtin.cc | 6 ++ src/tir/op/op.cc | 14 ++++ tests/python/unittest/test_tir_intrin.py | 50 ++++++++++++++ 10 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 src/target/source/literal/cuda_binary_search.h diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 86857a33cdf4..27936f5a8a76 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -494,6 +494,16 @@ TVM_DLL const Op& tvm_warp_shuffle_up(); TVM_DLL const Op& tvm_warp_shuffle_down(); TVM_DLL const Op& tvm_warp_activemask(); +/*! + * \brief Lower bound function for binary search. + */ +TVM_DLL const Op& tvm_lower_bound(); + +/*! + * \brief Upper bound function for binary search. + */ +TVM_DLL const Op& tvm_upper_bound(); + /*! * \brief Initialize the global barrier. * Call this at beginning of kernel that need global barrier. 
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 4d7fe80b28b1..4c105a142c02 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -110,6 +110,16 @@ def max_value(dtype, span): return tvm.tir.max_value(dtype, span) +@register +def lower_bound(arr, val, l, r, span): + return tvm.tir.lower_bound(arr, val, l, r, span) + + +@register +def upper_bound(arr, val, l, r, span): + return tvm.tir.upper_bound(arr, val, l, r, span) + + @register def floordiv(x, y, span): return tvm.tir.floordiv(x, y, span) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 44006239acfd..58c63825115f 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -36,7 +36,7 @@ from .function import PrimFunc from .op import call_packed, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace +from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace, lower_bound, upper_bound from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index de3ca5fa8d5b..41714538668d 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -971,6 +971,62 @@ def ldexp(x1, x2): return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore +def lower_bound(arr, val, l, r, span=None): + """Return the position of the first element in arr[l:r] that is no less than val. + + Parameters + ---------- + arr : Var + Pointer to the 1D buffer to apply binary search on. + + val : PrimExpr + Value whose lower bound is searched for in the buffer. + + l : PrimExpr + Start position (inclusive) of the search range in the buffer. + + r : PrimExpr + End position (exclusive) of the search range in the buffer. + + span : Optional[Span] + The location of this expression in the source code. 
+ + Returns + ------- + PrimExpr + The index of the first element in arr[l:r] that is no less than the given value. + """ + return _ffi_api.lower_bound(arr, val, l, r, span) # type: ignore + + +def upper_bound(arr, val, l, r, span=None): + """Return the position of the first element in arr[l:r] that is greater than val. + + Parameters + ---------- + arr : Var + Pointer to the 1D buffer to apply binary search on. + + val : PrimExpr + Value whose upper bound is searched for in the buffer. + + l : PrimExpr + Start position (inclusive) of the search range in the buffer. + + r : PrimExpr + End position (exclusive) of the search range in the buffer. + + span : Optional[Span] + The location of this expression in the source code. + + Returns + ------- + PrimExpr + The index of the first element in arr[l:r] that is greater than the given value. + """ + return _ffi_api.upper_bound(arr, val, l, r, span) # type: ignore + + def isnan(x, span=None): """Check if input value is Nan. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 0aad18ffb6f9..49a451c17832 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -33,6 +33,7 @@ #include #include "literal/cuda_half_t.h" +#include "literal/cuda_binary_search.h" namespace tvm { namespace codegen { @@ -132,6 +133,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + if (need_binary_search_) { + decl_stream << _cuda_binary_search_def; + } + decl_stream << "\n#ifdef _WIN32\n"; decl_stream << " using uint = unsigned int;\n"; decl_stream << " using uchar = unsigned char;\n"; @@ -723,6 +728,30 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? 
", " : ")"); } + } else if (op->op.same_as(builtin::tvm_lower_bound())) { + need_binary_search_ = true; + os << "__lower_bound("; + ICHECK_EQ(op->args.size(), 4U); + this->PrintExpr(op->args[0], os); + os << ", "; + this->PrintExpr(op->args[1], os); + os << ", "; + this->PrintExpr(op->args[2], os); + os << ", "; + this->PrintExpr(op->args[3], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_upper_bound())) { + need_binary_search_ = true; + os << "__upper_bound("; + ICHECK_EQ(op->args.size(), 4U); + this->PrintExpr(op->args[0], os); + os << ", "; + this->PrintExpr(op->args[1], os); + os << ", "; + this->PrintExpr(op->args[2], os); + os << ", "; + this->PrintExpr(op->args[3], os); + os << ")"; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 385b7343c8fd..18ad850e7cd6 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -99,6 +99,8 @@ class CodeGenCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; + // whether need binary search + bool need_binary_search_{false}; // Op attribute map OpAttrMap op_need_warp_shuffle_ = Op::GetAttrMap("cuda.need_warp_shuffle"); diff --git a/src/target/source/literal/cuda_binary_search.h b/src/target/source/literal/cuda_binary_search.h new file mode 100644 index 000000000000..becd3e33c0d6 --- /dev/null +++ b/src/target/source/literal/cuda_binary_search.h @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuda_binary_search.h + * \brief Binary search function definition for cuda codegen. + */ +#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_ +#define TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_ + +static constexpr const char* _cuda_binary_search_def = R"( +template +__forceinline__ __device__ int32_t __lower_bound( + const DType* __restrict__ arr, + DType val, + int32_t l, + int32_t r) { + int32_t low = l - 1, high = r; + /* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */ + while (low + 1 < high) { + int32_t mid = (low + high) >> 1; + if (arr[mid] < val) { + low = mid; + } else { + high = mid; + } + } + // high = low + 1, arr[low] < val, arr[high] >= val + return high; +} + +template +__forceinline__ __device__ int32_t __upper_bound( + const DType* __restrict__ arr, + DType val, + int32_t l, + int32_t r) { + int32_t low = l - 1, high = r; + /* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */ + while (low + 1 < high) { + int32_t mid = (low + high) >> 1; + if (arr[mid] > val) { + high = mid; + } else { + low = mid; + } + } + // high = low + 1, arr[low] <= val, arr[high] > val + return high; +} +)"; + +#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index c593cbf7290c..8166661ecc61 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -222,6 +222,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce) TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync) .set_attr("TCallEffectKind", 
Integer(CallEffectKind::kReadState)); +TIR_DEFINE_BUILTIN_FUNC(tvm_lower_bound) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(tvm_upper_bound) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index d08bef2ab91a..0ce984f5eeec 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -804,6 +804,16 @@ PrimExpr nearbyint(PrimExpr x, Span span) { TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); +// lower_bound +PrimExpr lower_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { + return tir::Call({kDLInt, 32, 1}, builtin::tvm_lower_bound(), {arr, val, l, r}, span); +} + +// upper_bound +PrimExpr upper_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { + return tir::Call({kDLInt, 32, 1}, builtin::tvm_upper_bound(), {arr, val, l, r}, span); +} + // trunc PrimExpr trunc(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -918,6 +928,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.lower_bound").set_body_typed(tvm::lower_bound); + +TVM_REGISTER_GLOBAL("tir.upper_bound").set_body_typed(tvm::upper_bound); + // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("tir." 
#Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 3e9e7fd33fd9..036db96b1c2c 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -253,6 +253,55 @@ def test_fma(): assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" +@tvm.script.tir +def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + n = tir.var('int32') + m = tir.var('int32') + A = tir.match_buffer(a, (n,), dtype='int32') + B = tir.match_buffer(b, (m,), dtype='int32') + C = tir.match_buffer(c, (m,), dtype='int32') + D = tir.match_buffer(d, (m,), dtype='int32') + with tir.block([m], 'search') as [vi]: + tir.reads([A[0:n], B[vi]]) + tir.writes([C[vi], D[vi]]) + C[vi] = tir.lower_bound(A.data, B[vi], 0, n) + D[vi] = tir.upper_bound(A.data, B[vi], 0, n) + + +def test_binary_search(): + sch = tir.Schedule(binary_search) + b = sch.get_block('search') + i, = sch.get_loops(b) + io, ii = sch.split(i, [1, None]) + sch.bind(io, 'threadIdx.x') + sch.bind(ii, 'blockIdx.x') + f = tvm.build(sch.mod['main'], target='cuda') + # print(f.imported_modules[0].get_source()) + + x = np.arange(-128, 128).astype(np.int32) + y = np.random.randint(-200, 200, size=1024).astype(np.int32) + a = np.zeros((1024,)).astype(np.int32) + b = np.zeros((1024,)).astype(np.int32) + + # numpy results + np_a = np.searchsorted(x, y, side='left').astype(np.int32) + np_b = np.searchsorted(x, y, side='right').astype(np.int32) + + # tvm results + dev = tvm.cuda(0) + x_array = tvm.nd.array(x, device=dev) + y_array = tvm.nd.array(y, device=dev) + a_array = tvm.nd.array(a, device=dev) + b_array = tvm.nd.array(b, device=dev) + f(x_array, y_array, a_array, b_array) + tvm_a = a_array.numpy() + tvm_b = b_array.numpy() + + # verify result + tvm.testing.assert_allclose(np_a, tvm_a) + tvm.testing.assert_allclose(np_b, tvm_b) + + if __name__ == 
"__main__": test_nearbyint() test_unary_intrin() @@ -261,3 +310,4 @@ def test_fma(): test_ldexp() test_clz() test_fma() + test_binary_search()