Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Feb 13, 2018
1 parent 360de22 commit 8bac8ee
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions topi/python/topi/mali/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32

gemm_factor = 4

if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64
and data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)):
if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and
data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)):
return _decl_winograd(data, kernel, stride, padding, layout, out_dtype)
elif kernel_shape[2:4] == (1, 1):
return _decl_im2col(data, kernel, stride, padding, layout, out_dtype)
Expand Down Expand Up @@ -559,42 +559,43 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
# pack input tile
input_tile = tvm.compute((C, P // bnb, alpha, alpha, bnb),
lambda c, b, eps, nu, bb:
data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps][(b*bnb+bb) % nW * m + nu],
data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]
[(b*bnb+bb) % nW * m + nu],
name='d')

# transform kernel
G = const_array(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
name='U')
tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
axis=[r_kh, r_kw]), name='U')

# transform image
B = const_array(B_data, 'B')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // bnb, C, bnb), lambda eps, nu, b, c, bb:
tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]),
name='V')
tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu],
axis=[r_eps, r_nu]), name='V')

# batch gemm
c = tvm.reduce_axis((0, C), name='c')
M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
tvm.sum(U[eps][nu][k // bna][c][k % bna] *
V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')
tvm.sum(U[eps][nu][k // bna][c][k % bna] *
V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')

# inverse transform
A = const_array(A_data, 'A')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], axis=[r_eps, r_nu]),
name='Y')
tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
axis=[r_eps, r_nu]), name='Y')

# unpack output
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
name='output', tag='winograd_conv_output')

return output
Expand Down Expand Up @@ -628,7 +629,7 @@ def _schedule_winograd(s, op):
eps, nu, k, c, kk, = s[U].op.axis
r_kh, r_kw = s[U].op.reduce_axis
s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw)
[s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
_ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
s[U].vectorize(kk)
tile_and_bind(s, U, k, c, 1, 256)

Expand All @@ -637,7 +638,7 @@ def _schedule_winograd(s, op):
eps, nu, b, c, bb = s[V].op.axis
r_eps, r_nu = s[V].op.reduce_axis
s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps)
[s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]]
_ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]]
s[V].vectorize(bb)
tile_and_bind(s, V, b, c, 2, 1)

Expand All @@ -661,7 +662,7 @@ def _schedule_winograd(s, op):
s[A].compute_inline()
k, b, vh, vw = s[Y].op.axis
r_eps, r_nu = s[Y].op.reduce_axis
[s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
_ = [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
tile_and_bind(s, Y, k, b, 4, 1)

# schedule output
Expand Down

0 comments on commit 8bac8ee

Please sign in to comment.