[TOPI] Add winograd for mali #898

Merged · 4 commits · Feb 13, 2018
225 changes: 212 additions & 13 deletions topi/python/topi/mali/conv2d.py
@@ -2,6 +2,8 @@
"""conv2d schedule on ARM Mali GPU"""

from __future__ import absolute_import as _abs

import numpy as np
import tvm

from .. import generic
@@ -63,7 +65,23 @@ def transpose(s, tensor, readers):
    s[tmp].compute_inline()
    return s.cache_write(tmp, "global"), tmp

-@conv2d.register("mali")
def const_array(data, name):
    """convert a constant array to a TVM tensor"""
    row, col = data.shape
    dtype = str(data.dtype)

    def select_array(i, j):
        now = tvm.const(0.0, dtype)
        for ii in range(row):
            for jj in range(col):
                now = tvm.select(tvm.all(i % row == ii, j % col == jj),
                                 tvm.const(data[ii][jj], dtype),
                                 now)
        return now
    return tvm.compute(data.shape, select_array, name=name)
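# e.g. const_array(np.array([[1, 0], [0, 1]], 'float32'), 'I') gives a 2x2
# tensor whose compute body selects the matching constant for each index, so
# the small transform matrices below are inlined into the generated code
# instead of being read from memory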


+@conv2d.register(["mali"])
def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    """Conv2D operator for ARM Mali GPU backend.

@@ -94,10 +112,20 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."

    out_dtype = data.dtype
-    if util.get_const_int(kernel.shape[2]) == 1:
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
    kernel_shape = util.get_const_tuple(kernel.shape)
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

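    # dispatch: winograd only for 3x3 / stride 1 / pad 1 kernels with >= 64
    # output channels (enough work to amortize the transform cost); 1x1 kernels
    # use im2col; everything else falls back to spatial packing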
    if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and
            (HSTR, WSTR) == (1, 1)):
        return _decl_winograd(data, kernel, stride, padding, layout, out_dtype)
    elif kernel_shape[2:4] == (1, 1):
        return _decl_im2col(data, kernel, stride, padding, layout, out_dtype)
    else:
-        return _decl_direct(data, kernel, stride, padding, layout, out_dtype)
+        return _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype)

@generic.schedule_conv2d_nchw.register(["mali"])
def schedule_conv2d_nchw(outs):
@@ -129,14 +157,17 @@ def traverse(op):
        if 'im2col_conv_output' in op.tag:
            _schedule_im2col_conv2d(s, op)

-        if 'direct_conv_output' in op.tag:
-            _schedule_direct_conv2d(s, op)
+        if 'spatialpack_conv_output' in op.tag:
+            _schedule_spatialpack_conv2d(s, op)

        if 'winograd_conv_output' in op.tag:
            _schedule_winograd(s, op)

    traverse(outs[0].op)
    return s

-def _decl_direct(data, kernel, stride, padding, layout, out_dtype):
-    """declare the direct method (spatial packing) for conv2d"""
+def _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype):
+    """declare the spatialpack (spatial packing) method for conv2d"""
    _, CI, IH, IW = [util.get_const_int(x) for x in data.shape]
    CO, _, KH, KW = [util.get_const_int(x) for x in kernel.shape]
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
@@ -207,12 +238,12 @@ def _decl_direct(data, kernel, stride, padding, layout, out_dtype):

    output = tvm.compute(oshape, lambda n, co, h, w:
                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
-                        name='output_unpack', tag='direct_conv_output')
+                        name='output_unpack', tag='spatialpack_conv_output')

    return output

-def _schedule_direct_conv2d(s, op):
-    """schedule the direct method (spatial packing) for conv2d"""
+def _schedule_spatialpack_conv2d(s, op):
+    """schedule the spatialpack (spatial packing) method for conv2d"""
    # get ops and tensors
    output = op.output(0)
    output_height = util.get_const_int(output.shape[2])
@@ -294,8 +325,6 @@ def _schedule_direct_conv2d(s, op):
    _, co, oh, ow = s[output].op.axis
    tile_and_bind3d(s, output, co, oh, ow, num_thread, 1, last)

-    #print(tvm.lower(s, [data, kernel, output], simple_mode=True))

def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    """declare the Im2Col method for conv2d"""
    _, CI, IH, IW = [x.value for x in data.shape]
@@ -476,4 +505,174 @@ def _schedule_im2col_conv2d(s, op):
    s[output].vectorize(vw)
    fuse_and_bind(s, output, [n, co, h, w])

-    #print(tvm.lower(s, [data, kernel], simple_mode=True))
def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
    """declare Winograd fast convolution F(2x2, 3x3) for conv2d"""
    N, CI, H, W = [util.get_const_int(x) for x in data.shape]
    CO, CI, KH, KW = [util.get_const_int(x) for x in kernel.shape]
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

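    # Winograd F(2x2, 3x3): each 2x2 output tile is computed as
    #   Y = A^T [(G g G^T) * (B^T d B)] A,   where '*' is the element-wise product,
    # using the constant transform matrices B (4x4), G (4x3), A (4x2) below
    # (cf. Lavin & Gray, "Fast Algorithms for Convolutional Neural Networks")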
    B_data = np.array([
        [1, 0, 0, 0],
        [0, 1, -1, 1],
        [-1, 1, 1, 0],
        [0, 0, 0, -1]
    ], out_dtype)

    G_data = np.array([
        [1, 0, 0],
        [1.0/2, 1.0/2, 1.0/2],
        [1.0/2, -1.0/2, 1.0/2],
        [0, 0, 1],
    ], out_dtype)

    A_data = np.array([
        [1, 0],
        [1, 1],
        [1, -1],
        [0, -1],
    ], out_dtype)

    m = 2
    r = 3
    alpha = m + r - 1
    K = CO
    C = CI

    nH, nW = (H + m-1) // m, (W + m-1) // m
    P = N * nH * nW
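    # nH x nW tiles of size m x m cover each image; P counts tiles over the batch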

    bna, bnb = 4, 4
    if data.dtype == 'float16':
        bnb *= 2
    P_round = (P + bnb - 1) // bnb * bnb
    assert K % bna == 0 and P_round % bnb == 0
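    # bna / bnb are the register-blocking (vector) widths for the kernel and
    # data tiles; bnb is doubled for float16, where twice as many lanes fit in
    # the same vector width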

    # pack input tile
    input_tile = tvm.compute((C, P_round // bnb, alpha, alpha, bnb),
                             lambda c, b, eps, nu, bb:
                             tvm.select(b * bnb + bb < P,
                                        data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]
                                        [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)),
                             name='d')
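    # the select guard zero-fills tiles in [P, P_round), so the packed buffer
    # can always be indexed in full bnb-wide vectors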

    # transform kernel
    G = const_array(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk:
                    tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
                            axis=[r_kh, r_kw]), name='U')
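    # U[eps][nu][ko][c][ki] = (G g G^T)[eps][nu] for filter (ko * bna + ki, c):
    # each 3x3 filter lifted into the 4x4 Winograd domain, blocked by bna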

    # transform image
    B = const_array(B_data, 'B')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    V = tvm.compute((alpha, alpha, P_round // bnb, C, bnb), lambda eps, nu, b, c, bb:
                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu],
                            axis=[r_eps, r_nu]), name='V')
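    # V[eps][nu][bo][c][bi] = (B^T d B)[eps][nu] for packed input tile (bo, bi),
    # the data-side counterpart of U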

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute((alpha, alpha, K, P_round), lambda eps, nu, k, b:
                    tvm.sum(U[eps][nu][k // bna][c][k % bna] *
                            V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')
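    # for each (eps, nu), M is a K x P_round matrix product reducing over the
    # channel axis: alpha * alpha independent GEMMs in total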

    # inverse transform
    A = const_array(A_data, 'A')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
                            axis=[r_eps, r_nu]), name='Y')
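    # Y = A^T M A collapses each 4x4 Winograd-domain tile back to a 2x2 block
    # of spatial output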

    # unpack output
    output = tvm.compute((N, K, H, W), lambda n, k, h, w:
                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
                         # the following term is used to make the padding effective,
                         # otherwise the padding stage would be eliminated by bound inference
                         + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][K-1][P_round-1],
                         name='output', tag='winograd_conv_output')

    return output

def _schedule_winograd(s, op):
    """schedule Winograd fast convolution F(2x2, 3x3) for conv2d"""

    # get ops and tensors
    output = op.output(0)

    Y = op.input_tensors[0]
    M, A = s[Y].op.input_tensors
    U, V = s[M].op.input_tensors
    kernel, G = s[U].op.input_tensors
    d, B = s[V].op.input_tensors
    data_pad = s[d].op.input_tensors[0]
    data = s[data_pad].op.input_tensors[0]

    # padding
    s[data_pad].compute_inline()

    # pack input tiles
    c, b, eps, nu, bb = s[d].op.axis
    s[d].reorder(eps, nu, bb)
    aha = s[d].fuse(eps, nu)
    s[d].unroll(bb)
    tile_and_bind3d(s, d, c, b, aha, 4, 1, 1)

    # transform kernel
    s[G].compute_inline()
    eps, nu, k, c, kk, = s[U].op.axis
    r_kh, r_kw = s[U].op.reduce_axis
    s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw)
    _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
    s[U].vectorize(kk)
    tile_and_bind(s, U, k, c, 1, 256)

    # transform image
    s[B].compute_inline()
    eps, nu, b, c, bb = s[V].op.axis
    r_eps, r_nu = s[V].op.reduce_axis
    s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps)
    _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]]
    s[V].vectorize(bb)
    tile_and_bind(s, V, b, c, 2, 1)

    # batch gemm
    bna, bnb = 4, 4
    if data.dtype == 'float16':
        bnb *= 2

    eps, nu, k, b = s[M].op.axis
    c = s[M].op.reduce_axis[0]
    yo, xo, yi, xi = s[M].tile(k, b, bna, bnb)
    s[M].reorder(c, yi, xi)
    c, c_unroll = s[M].split(c, 2)
    s[M].unroll(c_unroll)
    s[M].unroll(yi)
    s[M].vectorize(xi)
    z = s[M].fuse(eps, nu)
    tile_and_bind3d(s, M, z, yo, xo, 1, 8, 1)
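    # each thread computes one bna x bnb register tile of one of the
    # alpha * alpha GEMMs: the channel reduction is unrolled two at a time
    # (c_unroll), rows unrolled (yi) and columns vectorized (xi)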

    # inverse transform
    s[A].compute_inline()
    k, b, vh, vw = s[Y].op.axis
    r_eps, r_nu = s[Y].op.reduce_axis
    _ = [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
    tile_and_bind(s, Y, k, b, 4, 1)

    # schedule output
    if output.op in s.outputs:  # no bias
        output = output
    else:  # has bias
        s[output].compute_inline()
        output = s.outputs[0]

    _, k, h, w = s[output].op.axis
    tile_and_bind3d(s, output, k, h, w, 1, 2, 2)
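A minimal, untested sketch of exercising the new dispatch path; the shapes and the plain 'opencl' target string are illustrative assumptions, not part of this diff:

import tvm
from topi.mali import conv2d as mali_conv2d

# 3x3 kernel, stride 1, pad 1, 64 output channels: hits the winograd branch
data = tvm.placeholder((1, 64, 56, 56), name='data')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')

out = mali_conv2d.decl_conv2d(data, kernel, stride=1, padding=1)
s = mali_conv2d.schedule_conv2d_nchw([out])
f = tvm.build(s, [data, kernel, out], target='opencl')  # assumes an OpenCL-enabled build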