diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index ff67e0503f4f..5b4cf5bae6ff 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -2,6 +2,8 @@ """conv2d schedule on ARM Mali GPU""" from __future__ import absolute_import as _abs + +import numpy as np import tvm from .. import generic @@ -63,7 +65,23 @@ def transpose(s, tensor, readers): s[tmp].compute_inline() return s.cache_write(tmp, "global"), tmp -@conv2d.register("mali") +def const_array(data, name): + """ convert an const array to tvm tensor""" + row, col = data.shape + dtype = str(data.dtype) + + def select_array(i, j): + now = tvm.const(0.0, dtype) + for ii in range(row): + for jj in range(col): + now = tvm.select(tvm.all(i % row == ii, j % col == jj), + tvm.const(data[ii][jj], dtype), + now) + return now + return tvm.compute(data.shape, select_array, name=name) + + +@conv2d.register(["mali"]) def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): """Conv2D operator for ARM Mali GPU backend. @@ -94,10 +112,20 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 assert data.dtype == kernel.dtype, "Do not support inputs with different data types now." out_dtype = data.dtype - if util.get_const_int(kernel.shape[2]) == 1: + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + kernel_shape = util.get_const_tuple(kernel.shape) + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + + if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and + (HSTR, WSTR) == (1, 1)): + return _decl_winograd(data, kernel, stride, padding, layout, out_dtype) + elif kernel_shape[2:4] == (1, 1): return _decl_im2col(data, kernel, stride, padding, layout, out_dtype) else: - return _decl_direct(data, kernel, stride, padding, layout, out_dtype) + return _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype) @generic.schedule_conv2d_nchw.register(["mali"]) def schedule_conv2d_nchw(outs): @@ -129,14 +157,17 @@ def traverse(op): if 'im2col_conv_output' in op.tag: _schedule_im2col_conv2d(s, op) - if 'direct_conv_output' in op.tag: - _schedule_direct_conv2d(s, op) + if 'spatialpack_conv_output' in op.tag: + _schedule_spatialpack_conv2d(s, op) + + if 'winograd_conv_output' in op.tag: + _schedule_winograd(s, op) traverse(outs[0].op) return s -def _decl_direct(data, kernel, stride, padding, layout, out_dtype): - """declare the direct method (spatial packing) for conv2d""" +def _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype): + """declare the spatialpack method (spatial packing) for conv2d""" _, CI, IH, IW = [util.get_const_int(x) for x in data.shape] CO, _, KH, KW = [util.get_const_int(x) for x in kernel.shape] HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) @@ -207,12 +238,12 @@ def _decl_direct(data, kernel, stride, padding, layout, out_dtype): output = tvm.compute(oshape, lambda n, co, h, w: conv[n][co//VC][h/VH][w//VW][h%VH][w%VW][co%VC], - name='output_unpack', tag='direct_conv_output') + name='output_unpack', tag='spatialpack_conv_output') return output -def _schedule_direct_conv2d(s, op): - """schedule the direct method (spatial packing) for conv2d""" +def _schedule_spatialpack_conv2d(s, op): + """schedule the spatialpack method (spatial packing) for conv2d""" # get ops and tensors output = op.output(0) output_height = util.get_const_int(output.shape[2]) @@ -294,8 +325,6 @@ def _schedule_direct_conv2d(s, op): _, co, oh, ow = s[output].op.axis tile_and_bind3d(s, output, co, oh, ow, num_thread, 1, last) - #print(tvm.lower(s, [data, kernel, output], simple_mode=True)) - def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): """declare the Im2Col method for conv2d""" _, CI, IH, IW = [x.value for x in data.shape] @@ -476,4 +505,174 @@ def _schedule_im2col_conv2d(s, op): s[output].vectorize(vw) fuse_and_bind(s, output, [n, co, h, w]) - #print(tvm.lower(s, [data, kernel], simple_mode=True)) +def _decl_winograd(data, kernel, stride, padding, layout, out_dtype): + """declare winograd fast convolution F(2x2, 3x3) for conv2d""" + N, CI, H, W = [util.get_const_int(x) for x in data.shape] + CO, CI, KH, KW = [util.get_const_int(x) for x in kernel.shape] + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + + assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3 + data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + + B_data = np.array([ + [1, 0, 0, 0], + [0, 1, -1, 1], + [-1, 1, 1, 0], + [0, 0, 0, -1] + ], out_dtype) + + G_data = np.array([ + [1, 0, 0], + [1.0/2, 1.0/2, 1.0/2], + [1.0/2, -1.0/2, 1.0/2], + [0, 0, 1], + ], out_dtype) + + A_data = np.array([ + [1, 0], + [1, 1], + [1, -1], + [0, -1], + ], out_dtype) + + m = 2 + r = 3 + alpha = m + r - 1 + K = CO + C = CI + + nH, nW = (H + m-1) // m, (W + m-1) // m + P = N * nH * nW + + bna, bnb = 4, 4 + if data.dtype == 'float16': + bnb *= 2 + P_round = (P + bnb - 1) // bnb * bnb + assert K % bna == 0 and P_round % bnb == 0 + + # pack input tile + input_tile = tvm.compute((C, P_round // bnb, alpha, alpha, bnb), + lambda c, b, eps, nu, bb: + tvm.select(b * bnb + bb < P,\ + data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]\ + [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)), + name='d') + + # transform kernel + G = const_array(G_data, 'G') + r_kh = tvm.reduce_axis((0, KH), 'r_kh') + r_kw = tvm.reduce_axis((0, KW), 'r_kw') + U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk: + tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], + axis=[r_kh, r_kw]), name='U') + + # transform image + B = const_array(B_data, 'B') + r_eps = tvm.reduce_axis((0, alpha), 'r_eps') + r_nu = tvm.reduce_axis((0, alpha), 'r_nu') + V = tvm.compute((alpha, alpha, P_round // bnb, C, bnb), lambda eps, nu, b, c, bb: + tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu], + axis=[r_eps, r_nu]), name='V') + + # batch gemm + c = tvm.reduce_axis((0, C), name='c') + M = tvm.compute((alpha, alpha, K, P_round), lambda eps, nu, k, b: + tvm.sum(U[eps][nu][k // bna][c][k % bna] * + V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M') + + # inverse transform + A = const_array(A_data, 'A') + r_eps = tvm.reduce_axis((0, alpha), 'r_eps') + r_nu = tvm.reduce_axis((0, alpha), 'r_nu') + Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw: + tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], + axis=[r_eps, r_nu]), name='Y') + + # unpack output + output = tvm.compute((N, K, H, W), lambda n, k, h, w: + Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m] + # thw following term is used to make the padding effective, + # otherwise the padding will be eliminated by bound inference + + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][K-1][P_round-1], + name='output', tag='winograd_conv_output') + + return output + +def _schedule_winograd(s, op): + """schedule winograd fast convolution F(2x2, 3x3) for conv2d""" + + # get ops and tensors + output = op.output(0) + + Y = op.input_tensors[0] + M, A = s[Y].op.input_tensors + U, V = s[M].op.input_tensors + kernel, G = s[U].op.input_tensors + d, B = s[V].op.input_tensors + data_pad = s[d].op.input_tensors[0] + data = s[data_pad].op.input_tensors[0] + + # padding + s[data_pad].compute_inline() + + # pack input tiles + c, b, eps, nu, bb = s[d].op.axis + s[d].reorder(eps, nu, bb) + aha = s[d].fuse(eps, nu) + s[d].unroll(bb) + tile_and_bind3d(s, d, c, b, aha, 4, 1, 1) + + # transform kernel + s[G].compute_inline() + eps, nu, k, c, kk, = s[U].op.axis + r_kh, r_kw = s[U].op.reduce_axis + s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw) + _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] + s[U].vectorize(kk) + tile_and_bind(s, U, k, c, 1, 256) + + # transform image + s[B].compute_inline() + eps, nu, b, c, bb = s[V].op.axis + r_eps, r_nu = s[V].op.reduce_axis + s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps) + _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] + s[V].vectorize(bb) + tile_and_bind(s, V, b, c, 2, 1) + + # batch gemm + bna, bnb = 4, 4 + if data.dtype == 'float16': + bnb *= 2 + + eps, nu, k, b = s[M].op.axis + c = s[M].op.reduce_axis[0] + yo, xo, yi, xi = s[M].tile(k, b, bna, bnb) + s[M].reorder(c, yi, xi) + c, c_unroll = s[M].split(c, 2) + s[M].unroll(c_unroll) + s[M].unroll(yi) + s[M].vectorize(xi) + z = s[M].fuse(eps, nu) + tile_and_bind3d(s, M, z, yo, xo, 1, 8, 1) + + # inverse transform + s[A].compute_inline() + k, b, vh, vw = s[Y].op.axis + r_eps, r_nu = s[Y].op.reduce_axis + _ = [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]] + tile_and_bind(s, Y, k, b, 4, 1) + + # schedule output + if output.op in s.outputs: # no bias + output = output + else: # has bias + s[output].compute_inline() + output = s.outputs[0] + + _, k, h, w = s[output].op.axis + tile_and_bind3d(s, output, k, h, w, 1, 2, 2)