diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index a0b3c2000a80..ae87bf9d9b93 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -208,6 +208,25 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL os << "))"; } +void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) + /* The result type of a ternary expression is not always the same as that of its + * sub-expressions, so prefix tvm_if_then_else with a cast to the type of args[2]. */ + if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + os << "("; + PrintType(op->args[2].type(), os); + os << ")"; + } + CodeGenC::VisitExpr_(op, os); +} + +void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*) + /* The result type of a ternary expression is not always the same as that of its + * sub-expressions, so prefix the Select with a cast to the type of true_value. */ + os << "("; + PrintType(op->true_value.type(), os); + os << ")"; + CodeGenC::VisitExpr_(op, os); +} runtime::Module BuildOpenCL(Array funcs) { using tvm::runtime::Registry; diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 90569d176a0b..350b6c0f3402 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -38,6 +38,8 @@ class CodeGenOpenCL final : public CodeGenC { // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/tests/python/unittest/test_codegen_opencl.py b/tests/python/unittest/test_codegen_opencl.py new file mode 100644 index 000000000000..37aaadc5bb23 --- /dev/null +++ b/tests/python/unittest/test_codegen_opencl.py @@ -0,0 +1,55 @@ +import tvm + +target = 'opencl' + +def test_opencl_ternary_expression(): + def check_if_then_else(ctx, n, dtype): + A = tvm.placeholder((n,), name='A', dtype=dtype) + true_value = tvm.const(1, dtype=dtype) + false_value 
= tvm.const(3, dtype=dtype) + max_lhs = tvm.const(2, dtype=dtype) + max_rhs = tvm.if_then_else(A[0] > 0, true_value, false_value) + C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C') + s = tvm.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + def check_select(ctx, n, dtype): + A = tvm.placeholder((n,), name='A', dtype=dtype) + true_value = tvm.const(1, dtype=dtype) + false_value = tvm.const(3, dtype=dtype) + max_lhs = tvm.const(2, dtype=dtype) + max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value) + C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C') + s = tvm.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + if not tvm.module.enabled(target): + print("skip because opencl is not enabled..") + return + + ctx = tvm.context(target, 0) + + check_if_then_else(ctx, 1, 'int8') + check_if_then_else(ctx, 1, 'uint8') + check_if_then_else(ctx, 1, 'int16') + check_if_then_else(ctx, 1, 'uint16') + check_select(ctx, 1, 'int8') + check_select(ctx, 1, 'uint8') + check_select(ctx, 1, 'int16') + check_select(ctx, 1, 'uint16') + + +if __name__ == "__main__": + test_opencl_ternary_expression()