[Feature] x_dot_x builtin kernel support #831

Merged: 50 commits, Sep 14, 2019

Commits
6219fe6 - upd (yzh119, Aug 6, 2019)
66541cc - fig edgebatch edges (yzh119, Aug 7, 2019)
8258153 - add test (yzh119, Aug 7, 2019)
21404bb - Merge remote-tracking branch 'upstream/master' into fix (yzh119, Aug 7, 2019)
fb62f10 - trigger (yzh119, Aug 7, 2019)
ba138a3 - Merge remote-tracking branch 'upstream/master' (yzh119, Aug 12, 2019)
d725786 - Merge remote-tracking branch 'upstream/master' (yzh119, Aug 15, 2019)
e291d6d - Update README.md for pytorch PinSage example. (Aug 16, 2019)
eab01bb - Merge branch 'master' of https://github.com/classicsong/dgl (Aug 19, 2019)
5fdc289 - Provid a frame agnostic API to test nn modules on both CPU and CUDA s… (Aug 19, 2019)
2e89c6f - Fix style (classicsong, Aug 19, 2019)
b1af382 - Delete unused code (classicsong, Aug 19, 2019)
85630e3 - Make agnostic test only related to tests/backend (classicsong, Aug 19, 2019)
874352f - Fix code style (classicsong, Aug 19, 2019)
b918c9b - Merge remote-tracking branch 'upstream/master' (yzh119, Aug 19, 2019)
47e468a - fix (yzh119, Aug 19, 2019)
0769269 - Merge remote-tracking branch 'zihao/fix-nn' (classicsong, Aug 19, 2019)
10e1d27 - doc (yzh119, Aug 19, 2019)
e1b4864 - Merge remote-tracking branch 'zihao/fix-nn' (classicsong, Aug 19, 2019)
4bfe71c - Make all test code under tests.mxnet/pytorch.test_nn.py (classicsong, Aug 19, 2019)
a91b1bb - Fix syntex (classicsong, Aug 19, 2019)
475c0c3 - Merge branch 'master' into master (yzh119, Aug 21, 2019)
edf6a0e - Remove rand (classicsong, Aug 21, 2019)
ac8f6e4 - Merge branch 'master' of https://github.com/classicsong/dgl (classicsong, Aug 21, 2019)
b86e8db - Start implementing masked-mm kernel. (classicsong, Sep 3, 2019)
a0a55ec - Add masked dot declare (classicsong, Sep 3, 2019)
f27fddc - Update func/variable name (Sep 3, 2019)
d12d565 - Merge branch 'masked-mm' of https://github.com/classicsong/dgl into m… (Sep 3, 2019)
d882599 - Skeleton compile OK (Sep 3, 2019)
190102b - Update Implement. Unify BinaryDot with BinaryReduce (classicsong, Sep 4, 2019)
4fdb8fe - New Impl of x_dot_x, reuse binary reduce template (classicsong, Sep 4, 2019)
faa3b2d - Compile OK. (Sep 4, 2019)
edd378c - Merge branch 'master' into masked-mm (classicsong, Sep 5, 2019)
f9a0676 - Fix code style (Sep 5, 2019)
2bd6097 - Merge branch 'masked-mm' of https://github.com/classicsong/dgl into m… (Sep 5, 2019)
9c00adc - Now we can pass the tests/compute/test_kernel.py for add/sub/mul/div … (Sep 5, 2019)
037b142 - Fix mxnet test code (Sep 5, 2019)
c59de93 - Add u_dot_v, u_dot_e, v_dot_e unitest. (Sep 5, 2019)
bed5bcd - Update doc (classicsong, Sep 5, 2019)
056cbf5 - Now also support v_dot_u, e_dot_u, e_dot_v (Sep 5, 2019)
9e94b31 - Add unroll for some loop (Sep 5, 2019)
bdd3ff1 - Merge branch 'master' into masked-mm (classicsong, Sep 5, 2019)
89e61dd - Add some Opt for cuda backward of dot builtin. (Sep 6, 2019)
fea7cbb - Merge branch 'masked-mm' of github.com:classicsong/dgl into masked-mm (Sep 6, 2019)
73c5642 - Merge branch 'master' into masked-mm (yzh119, Sep 7, 2019)
a78edfb - Merge branch 'master' into masked-mm (classicsong, Sep 9, 2019)
c96338a - Apply UnravelRavel opt for broadcast backward (Sep 9, 2019)
cb7d1ac - update docstring (classicsong, Sep 11, 2019)
63563fe - Merge branch 'master' into masked-mm (VoVAllen, Sep 11, 2019)
b065a4e - Merge branch 'master' into masked-mm (yzh119, Sep 12, 2019)
6 changes: 6 additions & 0 deletions docs/source/api/python/function.rst
@@ -40,6 +40,12 @@ Message functions
e_sub_v
e_mul_v
e_div_v
u_dot_v
u_dot_e
v_dot_e
v_dot_u
e_dot_u
e_dot_v

Reduce functions
----------------
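
For reference, a minimal usage sketch of one of the new builtins documented above (the graph, feature name 'h', and output name 'score' are illustrative; apply_edges is the standard DGL entry point):

import torch
import dgl
import dgl.function as fn

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
g.ndata['h'] = torch.randn(3, 4)

# compute a dot-product score between source and destination features on every edge
g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
score = g.edata['score']  # one dot-product score per edge
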
4 changes: 4 additions & 0 deletions docs/source/features/builtin.rst
@@ -98,6 +98,10 @@ Here is a cheatsheet of all the DGL builtins.
| +----------------------------------------------------+-----------------------+
| | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v`` | |
| +----------------------------------------------------+-----------------------+
| | ``u_dot_v``, ``u_dot_e``, ``v_dot_e`` | |
| +----------------------------------------------------+-----------------------+
| | ``v_dot_u``, ``e_dot_u``, ``e_dot_v`` | |
| +----------------------------------------------------+-----------------------+
| | ``src_mul_edge`` | alias of ``u_mul_e`` |
+-------------------------+----------------------------------------------------+-----------------------+
| Reduce function | ``max`` | |
11 changes: 7 additions & 4 deletions python/dgl/backend/mxnet/tensor.py
@@ -376,8 +376,11 @@ def __init__(self, reducer, binary_op, graph, lhs, rhs, out_size, lhs_map,
def forward(self, lhs_data, rhs_data):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape(lhs_data_nd, rhs_data_nd)
out_data = nd.empty((self.out_size,) + feat_shape,
feat_shape = K.infer_binary_feature_shape(self.binary_op, lhs_data_nd, rhs_data_nd)
out_shape = feat_shape
if self.binary_op == 'dot':
out_shape = feat_shape[:-1]
out_data = nd.empty((self.out_size,) + out_shape,
ctx=lhs_data.context, dtype=lhs_data.dtype)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.binary_op_reduce(
@@ -402,10 +405,10 @@ def forward(self, lhs_data, rhs_data):
in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', self.graph, target, in_ones_nd, degs_nd,
'sum', self.graph, target, in_ones_nd, degs_nd,
in_map, self.out_map[0])
# reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
out_data = out_data / degs
else:
degs = None
11 changes: 7 additions & 4 deletions python/dgl/backend/pytorch/tensor.py
@@ -288,11 +288,14 @@ def forward(ctx, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
out_size, lhs_map, rhs_map, out_map):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape(lhs_data_nd, rhs_data_nd)
out_data = lhs_data.new_empty((out_size,) + feat_shape)
feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)
out_shape = feat_shape
if binary_op == 'dot':
out_shape = feat_shape[:-1]
out_data = lhs_data.new_empty((out_size,) + out_shape)
out_data_nd = zerocopy_to_dgl_ndarray(out_data)
K.binary_op_reduce(
reducer if reducer != 'mean' else 'sum',
reducer if reducer != 'mean' else 'sum',
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
out_data_nd, lhs_map[0], rhs_map[0], out_map[0])
# normalize if mean reducer
@@ -311,7 +314,7 @@ def forward(ctx, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
in_ones = lhs_data.new_ones((n,))
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
# reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
out_data = out_data / degs
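
Both backends now allocate the output with the same rule: for a 'dot' builtin the trailing feature (vector) dimension is consumed by the reduction, so it is dropped from the output feature shape. A small sketch of that rule (illustrative only, not part of the PR):

def _out_feat_shape(binary_op, feat_shape):
    # 'dot' reduces over the last (vector) dimension, so it disappears from
    # the output; every other builtin keeps the inferred broadcast shape.
    return feat_shape[:-1] if binary_op == 'dot' else feat_shape

# e.g. with length-4 feature vectors:
# _out_feat_shape('dot', (4,)) -> ()     one scalar per edge/node
# _out_feat_shape('mul', (4,)) -> (4,)   elementwise result keeps the vector
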
8 changes: 6 additions & 2 deletions python/dgl/function/message.py
@@ -147,12 +147,17 @@ def copy_e(e, out):

###############################################################################
# Generate all following builtin message functions:
# element-wise message functions:
# u_add_v, u_sub_v, u_mul_v, u_div_v
# u_add_e, u_sub_e, u_mul_e, u_div_e
# v_add_u, v_sub_u, v_mul_u, v_div_u
# v_add_e, v_sub_e, v_mul_e, v_div_e
# e_add_u, e_sub_u, e_mul_u, e_div_u
# e_add_v, e_sub_v, e_mul_v, e_div_v
#
# dot message functions:
# u_dot_v, u_dot_e, v_dot_e
# v_dot_u, e_dot_u, e_dot_v

_TARGET_MAP = {
"u": TargetCode.SRC,
@@ -200,12 +205,11 @@ def _register_builtin_message_func():
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs != rhs:
for binary_op in ["add", "sub", "mul", "div"]:
for binary_op in ["add", "sub", "mul", "div", "dot"]:
func = _gen_message_builtin(lhs, rhs, binary_op)
setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__)


_register_builtin_message_func()


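
With the extra 'dot' entry in the loop above, the generated builtins now include u_dot_v, u_dot_e, v_dot_e, v_dot_u, e_dot_u and e_dot_v. A short end-to-end sketch combining one of them with a reducer (graph and field names are illustrative):

import torch
import dgl
import dgl.function as fn

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
g.ndata['h'] = torch.randn(3, 4)
g.edata['w'] = torch.randn(3, 4)

# message: dot product of the source node feature with the edge feature;
# reduce: sum the scalar messages at each destination node
g.update_all(fn.u_dot_e('h', 'w', 'm'), fn.sum('m', 'h_dot'))
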
7 changes: 5 additions & 2 deletions python/dgl/kernel.py
@@ -4,11 +4,14 @@
from ._ffi.function import _init_api
from .ndarray import empty

def infer_binary_feature_shape(lhs, rhs):
# pylint: disable=invalid-name
def infer_binary_feature_shape(op, lhs, rhs):
"""Infer the output feature shape after a binary operation between lhs and rhs.

Parameter
---------
op : string
The binary_op name.
lhs : dgl.ndarray.NDArray
The lhs tensor.
rhs : dgl.ndarray.NDArray
@@ -19,7 +22,7 @@ def infer_binary_feature_shape(lhs, rhs):
tuple of int
The output feature shape.
"""
ret = _CAPI_DGLKernelInferBinaryFeatureShape(lhs, rhs)
ret = _CAPI_DGLKernelInferBinaryFeatureShape(op, lhs, rhs)
return tuple(ret.asnumpy())

# pylint: disable=invalid-name
34 changes: 25 additions & 9 deletions src/kernel/binary_reduce.cc
@@ -93,11 +93,25 @@ bool HasBcast(NDArray lhs, NDArray rhs) {
// e.g. (4, 1, 3, 3) and (4, 5, 3, 3) become (4, 1, 9) and (4, 5, 9)
//
// See also: BcastInfo (kernel/binary_reduce.h)
BcastInfo CalcBcastInfo(NDArray lhs, NDArray rhs) {
BcastInfo CalcBcastInfo(const std::string& op, NDArray lhs, NDArray rhs) {
BcastInfo ret;
const int max_ndim = std::max(lhs->ndim, rhs->ndim) - 1;
int64_t accum = 0;
for (int j = 0; j < max_ndim; ++j) {
int j = 0;
// for dot operation: vector [dot] vector
// lhs_shape[ndim-1] == rhs_shape[ndim-1] = sizeof(vector)
// out_shape[ndim-1] = 1
if (op == binary_op::kDot) {
// get size of vector
ret.data_len = lhs->shape[lhs->ndim - 1];
// skip vector size dim
++j;
ret.real_out_shape.push_back(ret.data_len);
} else { // op != binary_op::kDot
ret.data_len = 1;
}

for (; j < max_ndim; ++j) {
const int dl = (lhs->ndim - 1 - j < 1)? 1 : lhs->shape[lhs->ndim - 1 - j];
const int dr = (rhs->ndim - 1 - j < 1)? 1 : rhs->shape[rhs->ndim - 1 - j];
if (dl != dr) {
@@ -258,16 +272,18 @@ class BipartiteCSRWrapper : public CSRWrapper {


std::vector<int64_t> InferBinaryFeatureShape(
const std::string& op,
NDArray lhs,
NDArray rhs) {
return CalcBcastInfo(lhs, rhs).real_out_shape;
return CalcBcastInfo(op, lhs, rhs).real_out_shape;
}

DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelInferBinaryFeatureShape")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray lhs = args[0];
NDArray rhs = args[1];
const auto& shape = InferBinaryFeatureShape(lhs, rhs);
std::string op = args[0];
NDArray lhs = args[1];
NDArray rhs = args[2];
const auto& shape = InferBinaryFeatureShape(op, lhs, rhs);
const int64_t len = shape.size();
NDArray ret = NDArray::Empty(
{len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
@@ -300,7 +316,7 @@ void BinaryOpReduce(
rhs_mapping, lhs_mapping, out_mapping);
} else {
if (HasBcast(lhs_data, rhs_data)) {
BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
DGL_XPU_SWITCH(ctx.device_type, BinaryReduceBcastImpl,
info, reducer, op, graph,
lhs, rhs,
@@ -394,7 +410,7 @@ void BackwardLhsBinaryOpReduce(
grad_out_data, grad_lhs_data);
} else {
if (HasBcast(lhs_data, rhs_data)) {
BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceBcastImpl,
info, reducer, op, graph,
lhs, rhs,
@@ -468,7 +484,7 @@ void BackwardRhsBinaryOpReduce(
grad_out_data, grad_rhs_data);
} else {
if (HasBcast(lhs_data, rhs_data)) {
BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceBcastImpl,
info, reducer, op, graph,
lhs, rhs,
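
In CalcBcastInfo, the dot branch treats the trailing dimension specially: it must be the same vector length on both operands (recorded as data_len), and only the remaining leading dimensions go through the usual broadcast comparison. A rough Python mirror of that rule (illustrative only):

def dot_bcast_info(lhs_shape, rhs_shape):
    # both operands must carry vectors of the same length in the last dimension
    assert lhs_shape[-1] == rhs_shape[-1], "dot expects equal vector lengths"
    data_len = lhs_shape[-1]
    # the leading dimensions are broadcast exactly as for add/sub/mul/div
    return data_len, lhs_shape[:-1], rhs_shape[:-1]
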
2 changes: 2 additions & 0 deletions src/kernel/binary_reduce.h
@@ -30,6 +30,8 @@ struct BcastInfo {
std::vector<int64_t> lhs_shape, lhs_stride;
std::vector<int64_t> rhs_shape, rhs_stride;
std::vector<int64_t> out_shape, out_stride;

int64_t data_len;
};

/*
80 changes: 69 additions & 11 deletions src/kernel/binary_reduce_common.h
@@ -29,6 +29,7 @@ static const char kAdd[] = "add";
static const char kSub[] = "sub";
static const char kMul[] = "mul";
static const char kDiv[] = "div";
static const char kDot[] = "dot";
static const char kUseLhs[] = "use_lhs";

/*!
@@ -129,8 +130,8 @@ struct SwitchSrcDst<SelectDst> {
// common binary functors
template <typename DType>
struct BinaryAdd {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs + rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] + rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return 1;
@@ -142,8 +143,8 @@

template <typename DType>
struct BinaryMul {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs * rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] * rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return rhs;
@@ -155,8 +156,8 @@

template <typename DType>
struct BinarySub {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs - rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] - rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return 1;
@@ -168,8 +169,8 @@

template <typename DType>
struct BinaryDiv {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs / rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] / rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return static_cast<DType>(1) / rhs;
@@ -181,8 +182,8 @@

template <typename DType>
struct BinaryUseLhs {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return 1;
@@ -192,6 +193,25 @@ struct BinaryUseLhs {
}
};

template <typename DType>
struct BinaryDot {
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
DType out = 0;
// simple vector dot vector
#pragma unroll
(Review comment: There is already a pragma unroll at the graph level in minigun. Nested pragma unroll usually does not give benefit. Consider removing this.)
for (int i = 0; i < len; i ++)
out += lhs[i] * rhs[i];

return out;
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return rhs;
}
static DGLDEVICE DGLINLINE DType BackwardRhs(DType lhs, DType rhs, DType out) {
return lhs;
}
};

// Macro for dispatching op enum code and target code into template arguments.
// The macro dispatches following combinations:
// - Add(Src, Dst), Add(Src, Edge), Add(Dst, Edge)
Expand All @@ -201,6 +221,8 @@ struct BinaryUseLhs {
// - Div(Src, Dst), Div(Src, Edge), Div(Dst, Edge)
// Div(Dst, Src), Div(Edge, Src), Div(Edge, Dst)
// - UseLhs(Src, None), UseLhs(Edge, None)
// - Dot(Src, Dst), Dot(Src, Edge), Dot(Dst, Edge)
// - Dot(Dst, Src), Dot(Edge, Src), Dot(Edge, Dst)
// Note that for commutative operators (e.g. Add and Mul), we only generate
// kernels for lhs code smaller than rhs code.
#define OP_TARGET_SWITCH(op, lhs, rhs, DType, OpType, LeftType, RightType, ...) \
@@ -306,6 +328,36 @@ struct BinaryUseLhs {
typedef SelectEdge LeftType; \
typedef SelectNone RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kSrc && rhs == kDst) { \
typedef BinaryDot<DType> OpType; \
typedef SelectSrc LeftType; \
typedef SelectDst RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kSrc && rhs == kEdge) { \
typedef BinaryDot<DType> OpType; \
typedef SelectSrc LeftType; \
typedef SelectEdge RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kDst && rhs == kEdge) { \
typedef BinaryDot<DType> OpType; \
typedef SelectDst LeftType; \
typedef SelectEdge RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kDst && rhs == kSrc) { \
typedef BinaryDot<DType> OpType; \
typedef SelectDst LeftType; \
typedef SelectSrc RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kEdge && rhs == kSrc) { \
typedef BinaryDot<DType> OpType; \
typedef SelectEdge LeftType; \
typedef SelectSrc RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kEdge && rhs == kDst) { \
typedef BinaryDot<DType> OpType; \
typedef SelectEdge LeftType; \
typedef SelectDst RightType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Unsupported operation: op=" << op \
<< " lhs=" << lhs << " rhs=" << rhs; \
@@ -333,7 +385,13 @@ struct BinaryUseLhs {
MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectEdge, BinaryDiv)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectDst, BinaryDiv)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectNone, BinaryUseLhs)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectNone, BinaryUseLhs))
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectNone, BinaryUseLhs)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectDst, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectEdge, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectEdge, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectSrc, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectSrc, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectDst, BinaryDot))

//////////////////////////////////////////////////////////////////////////
// Defines reducer category. Each category is an empty structure.
Expand Down
Loading