diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index 4be4161540f8..8f59040b23a6 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -122,8 +122,8 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 gemm_factor = 4 - if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 - and data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)): + if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and + data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)): return _decl_winograd(data, kernel, stride, padding, layout, out_dtype) elif kernel_shape[2:4] == (1, 1): return _decl_im2col(data, kernel, stride, padding, layout, out_dtype) @@ -559,7 +559,8 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype): # pack input tile input_tile = tvm.compute((C, P // bnb, alpha, alpha, bnb), lambda c, b, eps, nu, bb: - data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps][(b*bnb+bb) % nW * m + nu], + data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps] + [(b*bnb+bb) % nW * m + nu], name='d') # transform kernel @@ -567,34 +568,34 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype): r_kh = tvm.reduce_axis((0, KH), 'r_kh') r_kw = tvm.reduce_axis((0, KW), 'r_kw') U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk: - tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), - name='U') + tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], + axis=[r_kh, r_kw]), name='U') # transform image B = const_array(B_data, 'B') r_eps = tvm.reduce_axis((0, alpha), 'r_eps') r_nu = tvm.reduce_axis((0, alpha), 'r_nu') V = tvm.compute((alpha, alpha, P // bnb, C, bnb), lambda eps, nu, b, c, bb: - tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), - name='V') + tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu], + axis=[r_eps, r_nu]), name='V') # batch gemm c = tvm.reduce_axis((0, C), name='c') M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b: - tvm.sum(U[eps][nu][k // bna][c][k % bna] * - V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M') + tvm.sum(U[eps][nu][k // bna][c][k % bna] * + V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M') # inverse transform A = const_array(A_data, 'A') r_eps = tvm.reduce_axis((0, alpha), 'r_eps') r_nu = tvm.reduce_axis((0, alpha), 'r_nu') Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw: - tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], axis=[r_eps, r_nu]), - name='Y') + tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], + axis=[r_eps, r_nu]), name='Y') # unpack output output = tvm.compute((N, K, H, W), lambda n, k, h, w: - Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], + Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], name='output', tag='winograd_conv_output') return output @@ -628,7 +629,7 @@ def _schedule_winograd(s, op): eps, nu, k, c, kk, = s[U].op.axis r_kh, r_kw = s[U].op.reduce_axis s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw) - [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] + _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] s[U].vectorize(kk) tile_and_bind(s, U, k, c, 1, 256) @@ -637,7 +638,7 @@ def _schedule_winograd(s, op): eps, nu, b, c, bb = s[V].op.axis r_eps, r_nu = s[V].op.reduce_axis s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps) - [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] + _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] s[V].vectorize(bb) tile_and_bind(s, V, b, c, 2, 1) @@ -661,7 +662,7 @@ def _schedule_winograd(s, op): s[A].compute_inline() k, b, vh, vw = s[Y].op.axis r_eps, r_nu = s[Y].op.reduce_axis - [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]] + _ = [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]] tile_and_bind(s, Y, k, b, 4, 1) # schedule output