[TOPI] [Relay] Sparse Conv2d Implementation for 3x3 kernels (apache#8605)

* [topi] add spconv2d_3x3 nhwc

* [relay] sparse_conv2d: add kernel_size attr

* [relay] add strategy for spconv2d_3x3 nhwc

* [relay] pass to convert spconv2d with const args

* [relay] convert sparse conv2d pass fixes

* use array for sparse conv2d attr

* fixup 1x1 tests; new 3x3 tests
Tantalus13A98B5F authored and Andrew Zhao Luo committed Sep 1, 2021
1 parent 3862ce6 commit 934b4e5
Showing 12 changed files with 548 additions and 50 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1066,12 +1066,16 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
/*! \brief Attributes for sparse_dense operator */
struct SparseConv2DAttrs : public tvm::AttrsNode<SparseConv2DAttrs> {
std::string layout;
Array<IndexExpr> kernel_size;

TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC'"
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(kernel_size)
.set_default(Array<IndexExpr>{1, 1})
.describe("Kernel size for SparseConv2D, 1x1 or 3x3. ");
}
};
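The new field rides on TVM's attrs reflection, so it can be inspected from Python. A minimal sketch; the make_node construction path and the field values here are assumptions for illustration, not part of this diff:

import tvm

# Hypothetical: build the attrs node via reflection to see the new field.
attrs = tvm.ir.make_node("relay.attrs.SparseConv2DAttrs", layout="NHWC", kernel_size=[3, 3])
print(attrs.layout, list(attrs.kernel_size))  # -> NHWC [3, 3]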

15 changes: 8 additions & 7 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -254,13 +254,14 @@ def ref_input(self):

@ref_input.setter
def ref_input(self, val):
warnings.warn(
"You are specifying fixed input for tuning the operator. "
"Be sure your input always fits the operator. Some "
"operators may conduct layout transformation during tuning, "
"thus can lead to unexpected behaviors. ",
RuntimeWarning,
)
if val is not None:
warnings.warn(
"You are specifying fixed input for tuning the operator. "
"Be sure your input always fits the operator. Some "
"operators may conduct layout transformation during tuning, "
"thus can lead to unexpected behaviors. ",
RuntimeWarning,
)
self._ref_input = val

def set_task(self, task):
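In effect, clearing the reference input is now silent while setting one still warns. A short usage sketch; the LocalRunner construction is illustrative, only the setter behavior comes from this diff:

import numpy as np
from tvm import autotvm

runner = autotvm.LocalRunner(number=1, repeat=1)
runner.ref_input = None                             # no RuntimeWarning: clearing is intentional
runner.ref_input = [np.zeros((16, 16), "float32")]  # warns: a fixed input may not fit every layout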
58 changes: 38 additions & 20 deletions python/tvm/relay/analysis/sparse_conv2d.py
@@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr):
return _ffi_api.search_conv2d_op_weight(expr)


def process_params(expr, params, block_size, sparsity_threshold, layout):
def process_params(
expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True
):
"""Process parameters of conv2d from dense to sparse.
Parameters
@@ -86,14 +88,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
for name in weight_names:
name = str(name)
w_np = params[name].numpy()
# currently only support conv2d_1*1
if not (
(w_np.shape[0] == 1 and w_np.shape[1] == 1)
or (w_np.shape[2] == 1 and w_np.shape[3] == 1)
):

if layout == "NHWC": # HWIO
weight_kernel = (w_np.shape[0], w_np.shape[1])
elif layout == "NCHW": # OIHW
weight_kernel = (w_np.shape[2], w_np.shape[3])
if weight_kernel[0] != weight_kernel[1]:
continue
sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
if sparsity >= sparsity_threshold:

if weight_kernel[0] == kernel_size == 1:
sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
if sparsity < sparsity_threshold:
continue
if layout == "NHWC":
w_np = w_np.squeeze().T
elif layout == "NCHW":
@@ -108,19 +114,31 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
)
else:
sparse_weight_data = sparse_weight.data
elif weight_kernel[0] == kernel_size == 3:
if layout == "NHWC": # HWIO
w_np = w_np.reshape((-1, w_np.shape[-1])).T
elif layout == "NCHW": # OIHW
w_np = w_np.reshape((w_np.shape[0], -1))
sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold:
continue
sparse_weight_data = sparse_weight.data
else:
continue

# remove dense weight
del params[name]
memo.weight_name.append(name)
memo.weight_shape.append(
list(sparse_weight_data.shape)
+ list(sparse_weight.indices.shape)
+ list(sparse_weight.indptr.shape)
)
params[name + ".data"] = tvm.nd.array(sparse_weight_data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

# remove dense weight
del params[name]
memo.weight_name.append(name)
memo.weight_shape.append(
list(sparse_weight_data.shape)
+ list(sparse_weight.indices.shape)
+ list(sparse_weight.indptr.shape)
)
params[name + ".data"] = tvm.nd.array(sparse_weight_data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

if reg_task_input:
prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
w_np.shape[0],
w_np.shape[1],
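The 3x3 branch above flattens the spatial taps into a 2-D matrix before the BSR conversion. A self-contained sketch of that transformation; shapes, sparsity level, and blocksize are illustrative:

import numpy as np
import scipy.sparse as sp

w_np = np.random.rand(3, 3, 64, 64)            # HWIO kernel for layout="NHWC"
w_np[np.random.rand(*w_np.shape) < 0.8] = 0.0  # ~80% unstructured sparsity
w2d = w_np.reshape((-1, w_np.shape[-1])).T     # (O, H*W*I), exactly as in the pass
bsr = sp.bsr_matrix(w2d, blocksize=(1, 1))
print(1.0 - bsr.nnz / w2d.size)                # block-level sparsity, compared to the threshold
print(bsr.data.shape, bsr.indices.shape, bsr.indptr.shape)  # the three tensors stored in params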
44 changes: 40 additions & 4 deletions python/tvm/relay/data_dep_optimization/bsr_conv2d.py
@@ -23,8 +23,8 @@
from .utils import _run_opt_pass


def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
"""Convert a dense func and according parameters to block sparse
def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1):
"""Convert a conv2d func and according parameters to block sparse
Parameters
----------
@@ -49,10 +49,46 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
params: Dict[String, tvm.nd.array]
New params with BSR matrix for mutated Expr
"""
weight_info = process_params(func, params, blocksize, sparsity_threshold, layout)
weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size)
new_func = _run_opt_pass(
func,
relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout),
relay.transform.Conv2dToSparse(
weight_info.weight_name, weight_info.weight_shape, layout, kernel_size
),
)

return new_func, params


def convert2(func, params, blocksize, sparsity_threshold, layout, kernel_size):
"""Convert a freezed conv2d func to block sparse
Parameters
----------
func : relay.Expr
Expr to be optimized to a sparse operation, with params frozen
params : Dict[String, tvm.nd.array]
Parameters of the Expr (not used in this pass)
blocksize : Tuple(int, int)
Blocksize for BSR matrix
sparsity_threshold : float
Minimal sparsity requirement for converting.
If weight sparsity is lower than this threshold,
the dense operation will be kept.
layout : str
layout of network
kernel_size : int
kernel size of the conv2d, for filtering
Returns
-------
new_func: relay.Expr
Mutated Expr with sparse operations
params: Dict[String, tvm.nd.array]
New params with BSR matrix for mutated Expr (not modified)
"""
new_func = _run_opt_pass(
func, relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold)
)
return new_func, params
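A hypothetical calling sequence for the two helpers; func and params come from an existing Relay workload, and the blocksize and threshold values are illustrative:

from tvm.relay.data_dep_optimization import bsr_conv2d

# 1x1 path: qualifying weights in `params` are replaced by BSR data/indices/indptr
func, params = bsr_conv2d.convert(func, params, (16, 1), 0.8, layout="NHWC", kernel_size=1)

# 3x3 path for a function whose weights are frozen into the graph; `params` is untouched
func, params = bsr_conv2d.convert2(func, params, (16, 1), 0.8, "NHWC", 3)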
6 changes: 5 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -198,7 +198,11 @@ def compute_sparse_transpose(attrs, inputs, out_type):
@reg.register_compute("nn.sparse_conv2d")
def compute_sparse_conv2d(attrs, inputs, out_type):
"""Compute definition of sparse_conv2d"""
return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])]
return [
topi.nn.sparse_conv2d(
inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"]
)
]


reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy)
25 changes: 25 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -565,6 +565,31 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
return strategy


@sparse_conv2d_strategy.register("cpu")
def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target):
"""sparse conv2d x86 strategy"""
strategy = _op.OpStrategy()
if attrs["kernel_size"][0] == 1:
strategy.add_implementation(
wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d),
wrap_topi_schedule(topi.generic.schedule_sparse_conv2d),
name="sparse_conv2d.generic",
)
elif attrs["kernel_size"][0] == 3:
if attrs["layout"] == "NHWC":
strategy.add_implementation(
wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc),
wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc),
name="conv3x3_spNHWC.x86",
)
elif attrs["layout"] == "NCHW":
strategy.add_implementation(
wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw),
wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw),
)
return strategy


@roi_align_strategy.register("cpu")
def roi_align_strategy_cpu(attrs, inputs, out_type, target):
"""roi_align x86 strategy"""
24 changes: 22 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape):
return _ffi_api.DenseToSparse(weight_name, weight_shape)


def Conv2dToSparse(weight_name, weight_shape, layout):
def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size):
"""
Rewrite qualified ```nn.conv2d``` operation to ```nn.sparse_conv2d```
@@ -1113,7 +1113,27 @@ def Conv2dToSparse(weight_name, weight_shape, layout):
ret : tvm.transform.Pass
The registered Conv2dToSparse pass.
"""
return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout)
return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size)


def Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold):
"""
Rewrite frozen ```nn.conv2d``` operation to ```nn.sparse_conv2d```
Parameters
----------
layout : str
layout of data
kernel_size : int
kernel size of conv2d
Returns
-------
ret : tvm.transform.Pass
The registered Conv2dToSparse2 pass.
"""
return _ffi_api.Conv2dToSparse2(layout, kernel_size, *blocksize, sparsity_threshold)


def SimplifyFCTranspose(target_weight_name):
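A sketch of driving the new Conv2dToSparse2 pass directly rather than through bsr_conv2d.convert2; the module `mod` is assumed to have its conv2d weights bound as constants, and the blocksize and threshold are illustrative:

import tvm
from tvm import relay

sparse_pass = relay.transform.Conv2dToSparse2("NHWC", 3, (16, 1), 0.8)
with tvm.transform.PassContext(opt_level=3):
    mod = sparse_pass(mod)  # rewrites qualifying nn.conv2d calls to nn.sparse_conv2d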
21 changes: 12 additions & 9 deletions python/tvm/topi/nn/sparse.py
@@ -566,7 +566,9 @@ def _compute_block(i, nb_j, j, h, w):  # pylint: disable=C0103
)


def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"):
def sparse_conv2d(
dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1
):
"""
Computes sparse-conv2d (1x1) of ``data`` and
``(weight_data, weight_indices, weight_indptr)``
@@ -598,14 +600,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout
4-D with shape [M, H, W, N] (layout=NHWC)
4-D with shape [M, N, H ,W] (layout=NCHW)
"""
if layout == "NHWC":
return _sparse_conv2d_bsr_compute_nhwc(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
elif layout == "NCHW":
return _sparse_conv2d_bsr_compute_nchw(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
if kernel_size == 1:
if layout == "NHWC":
return _sparse_conv2d_bsr_compute_nhwc(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
elif layout == "NCHW":
return _sparse_conv2d_bsr_compute_nchw(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
else:
raise ValueError("Unsupport Layout %s" % layout)

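A minimal compute-declaration sketch for the 1x1 NHWC path above, for a 64-to-64-channel conv with blocksize (16, 1); the block count and all shapes are illustrative:

from tvm import te, topi

data = te.placeholder((1, 56, 56, 64), name="data")                 # NHWC input
w_data = te.placeholder((96, 16, 1), name="w_data")                 # BSR block values
w_indices = te.placeholder((96,), dtype="int32", name="w_indices")  # block column indices
w_indptr = te.placeholder((5,), dtype="int32", name="w_indptr")     # 64 / 16 + 1 row pointers
out = topi.nn.sparse_conv2d(data, w_data, w_indices, w_indptr, layout="NHWC", kernel_size=1)
print(out.shape)  # expected [1, 56, 56, 64]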