Skip to content

Commit

Permalink
use contrib.num_elements
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Jun 13, 2019
1 parent a09988e commit 48d69ac
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class OperationNode : public FunctionBaseNode {
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief addtitional attributes of the operation*/
/*! \brief additional attributes of the operation*/
Map<std::string, NodeRef> attrs;
/*! \return name of the operation */
const std::string& func_name() const final {
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,10 @@ struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
};

/*! \brief Attributes for Size operator */
struct SizeAttrs : public tvm::AttrsNode<SizeAttrs> {
struct NumElementsAttrs : public tvm::AttrsNode<NumElementsAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(SizeAttrs, "relay.attrs.SizeAttrs") {
TVM_DECLARE_ATTRS(NumElementsAttrs, "relay.attrs.NumElementsAttrs") {
TVM_ATTR_FIELD(dtype)
.describe("Target data type")
.set_default(NullValue<DataType>());
Expand Down
18 changes: 9 additions & 9 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,34 +280,34 @@ RELAY_REGISTER_OP("shape_of")
.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);


TVM_REGISTER_NODE_TYPE(SizeAttrs);
TVM_REGISTER_NODE_TYPE(NumElementsAttrs);

bool SizeRel(const Array<Type>& types,
bool NumElementsRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt != nullptr);
const auto* param = attrs.as<SizeAttrs>();
const auto* param = attrs.as<NumElementsAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
return true;
}

Array<Tensor> SizeCompute(const Attrs& attrs,
Array<Tensor> NumElementsCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<SizeAttrs>();
const auto* param = attrs.as<NumElementsAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{topi::size(inputs[0], param->dtype)};
}

TVM_REGISTER_API("relay.op.contrib._make.size")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
auto attrs = make_node<SizeAttrs>();
auto attrs = make_node<NumElementsAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("contrib.num_elements");
return CallNode::make(op, {data}, Attrs(attrs), {});
Expand All @@ -318,15 +318,15 @@ RELAY_REGISTER_OP("contrib.num_elements")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.SizeAttrs")
.set_attrs_type_key("relay.attrs.NumElementsAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("Size", SizeRel)
.add_type_rel("NumElements", NumElementsRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", SizeCompute);
.set_attr<FTVMCompute>("FTVMCompute", NumElementsCompute);

} // namespace relay
} // namespace tvm
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_shape_of():
def test_size():
def verify_size(shape):
x = relay.var("x", shape=shape)
func = relay.Function([x], relay.op.contrib.tensor_size(x))
func = relay.Function([x], relay.op.contrib.num_elements(x))
func = relay.ir_pass.infer_type(func)

x_data = np.random.uniform(size=shape).astype("float32")
Expand Down
1 change: 0 additions & 1 deletion topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,6 @@ inline Tensor size(const Tensor& src,
Array<Expr> out_size;
out_size.push_back(1);
return compute(out_size, [&](const Array<Var>& indices) {
auto idx = indices[0];
Expr ret = 1;
for (int i = 0; i < ndim; ++i) {
ret *= src->shape[i];
Expand Down

0 comments on commit 48d69ac

Please sign in to comment.