Skip to content

Commit

Permalink
rename to ndarray_size
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Jul 16, 2019
1 parent 25de6b3 commit 6018bca
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
18 changes: 9 additions & 9 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,34 +280,34 @@ RELAY_REGISTER_OP("shape_of")
.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);


TVM_REGISTER_NODE_TYPE(NumElementsAttrs);
TVM_REGISTER_NODE_TYPE(SizeAttrs);

bool NumElementsRel(const Array<Type>& types,
bool SizeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt != nullptr);
const auto* param = attrs.as<NumElementsAttrs>();
const auto* param = attrs.as<SizeAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
return true;
}

Array<Tensor> NumElementsCompute(const Attrs& attrs,
Array<Tensor> SizeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<NumElementsAttrs>();
const auto* param = attrs.as<SizeAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{topi::size(inputs[0], param->dtype)};
}

TVM_REGISTER_API("relay.op.contrib._make.size")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
auto attrs = make_node<NumElementsAttrs>();
auto attrs = make_node<SizeAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("contrib.ndarray_size");
return CallNode::make(op, {data}, Attrs(attrs), {});
Expand All @@ -318,15 +318,15 @@ RELAY_REGISTER_OP("contrib.ndarray_size")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.NumElementsAttrs")
.set_attrs_type_key("relay.attrs.SizeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("NumElements", NumElementsRel)
.add_type_rel("Size", SizeRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", NumElementsCompute);
.set_attr<FTVMCompute>("FTVMCompute", SizeCompute);

} // namespace relay
} // namespace tvm
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def test_shape_of():
def test_ndarray_size():
def verify_ndarray_size(shape):
x = relay.var("x", shape=shape)
func = relay.Function([x], relay.op.contrib.num_elements(x))
func = relay.ir_pass.infer_type(func)
func = relay.Function([x], relay.op.contrib.ndarray_size(x))
func = run_infer_type(func)

x_data = np.random.uniform(size=shape).astype("float32")
ref_res = np.size(x_data)
Expand Down
1 change: 0 additions & 1 deletion topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,4 +496,3 @@ def size(array, dtype="int32"):
The resulting tensor.
"""
return cpp.size(array, dtype)

0 comments on commit 6018bca

Please sign in to comment.