Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INTRIN] Add support for floor and ceil #1267

Merged
merged 1 commit into from
Jun 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ set(USE_METAL OFF)
# Whether enable Vulkan runtime
set(USE_VULKAN OFF)

# Whether enable OpenGL runtime
set(USE_OPENGL OFF)

# Whether enable RPC runtime
set(USE_RPC ON)

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);

inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,38 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x)


def floor(x):
"""Take floor of float input x.

Parameters
----------
x : Expr
Input argument.

Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "floor", x)


def ceil(x):
"""Take ceil of float input x.

Parameters
----------
x : Expr
Input argument.

Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "ceil", x)


def power(x, y):
"""x power y

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ struct CUDAShuffle {
}
};

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_opengl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {

namespace llvm {

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);

Expand Down
3 changes: 0 additions & 3 deletions src/codegen/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
* \file build_vulkan.cc
* \brief Build SPIRV block
*/
#if TVM_VULKAN_RUNTIME

// Use libspirv for parsing and validating code.
#include <vulkan/libspirv.h>
#include <dmlc/memory_io.h>
Expand Down Expand Up @@ -92,4 +90,3 @@ TVM_REGISTER_API("codegen.build_vulkan")

} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
5 changes: 0 additions & 5 deletions src/codegen/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
* \file codegen_spirv.cc
* \brief Generate SPIRV block
*/

#if TVM_VULKAN_RUNTIME

#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "../codegen_common.h"
Expand Down Expand Up @@ -634,5 +631,3 @@ void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {

} // namespace codegen
} // namespace tvm

#endif // TVM_VULKAN_RUNTIME
13 changes: 9 additions & 4 deletions src/codegen/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc
*/
#if TVM_VULKAN_RUNTIME

#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <vulkan/GLSL.std.450.h>
Expand Down Expand Up @@ -31,6 +29,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

Expand All @@ -43,8 +47,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);

} // namespace spirv
} // namespace codegen
} // namespace tvm

#endif // TVM_VULKAN_RUNTIME
5 changes: 0 additions & 5 deletions src/codegen/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
* \file ir_builder.cc
* \brief IRBuilder for SPIRV block
*/

#if TVM_VULKAN_RUNTIME

#include "./ir_builder.h"

namespace tvm {
Expand Down Expand Up @@ -555,5 +552,3 @@ Value IRBuilder::Select(Value cond, Value a, Value b) {
} // namespace spirv
} // namespace codegen
} // namespace tvm

#endif // TVM_VULKAN_RUNTIME
2 changes: 2 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);

/*!
* \brief Creates an operation that returns identity of a given tensor
Expand Down
34 changes: 34 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,40 @@ def tanh(x):
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
"""Take floor of input x.

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.floor(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def ceil(x):
"""Take ceil of input x.

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def log(x):
"""Take logarithm of input x.
Expand Down
18 changes: 10 additions & 8 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ def test_ewise():

shape = (20, 3)

def test_apply(func, name, f_numpy):
def test_apply(func, name, f_numpy, low, high):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
a_np = np.random.uniform(low=1e-5, size=shape).astype(A.dtype)
a_np = np.abs(a_np)
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
b_np = f_numpy(a_np)

def check_device(device):
Expand All @@ -43,11 +42,14 @@ def check_device(device):
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']:
check_device(device)

test_apply(topi.exp, "exp", np.exp)
test_apply(topi.tanh, "tanh", np.tanh)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)))
test_apply(topi.log, "log", np.log)
test_apply(topi.sqrt, "sqrt", np.sqrt)

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)

if __name__ == "__main__":
test_util()
Expand Down