Skip to content

Commit

Permalink
Merge pull request apache#7 from cmu-catalyst/n_to_1_lowering
Browse files Browse the repository at this point in the history
N to 1 lowering
  • Loading branch information
MadFunMaker committed Apr 21, 2021
2 parents 9a76b6f + 97656f6 commit 465ff28
Show file tree
Hide file tree
Showing 27 changed files with 2,030 additions and 155 deletions.
131 changes: 112 additions & 19 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,8 @@ def softmax(x, axis=-1):
name="y",
)

#inputs[0], pool_size, strides, padding, "max", ceil_mode, data_layout, True)

def max_pool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad):
def maxpool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad):
"""Compute softmax using CuDNN
Parameters
Expand Down Expand Up @@ -501,21 +500,30 @@ def max_pool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout
int horizontalStride = args[11];
"""
print("Python cudnn.py pool2d!!", file=sys.stderr)
dims = 2
data_shape = x.shape
output_shape = list(data_shape)
#outputDim = 1 + (inputDim + 2*padding - windowDim)/poolingStride;
for i in range(dims):
output_shape[i+2] = int(1 + tvm.tir.div((data_shape[i+2] + 2*padding[i]-pool_size[i]),strides[i]))

return te.extern(
x.shape,
output_shape,
[x],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cudnn.pooling.forward", ins[0], outs[0],
1.0, 0.0, 0, 0, pool_size[0], pool_size[1], padding[0], padding[1],
1, 0, # alpha, beta
3, # MODE: CUDNN_POOLING_MAX_DETERMINISTIC
0, # CUDNN_NOT_PROPAGATE_NAN
pool_size[0], pool_size[1],
padding[0], padding[1],
strides[0], strides[1]
),
name="y",
)


def relu(x):
print("Python cudnn.py relu!!", file=sys.stderr)
return te.extern(
x.shape,
[x],
Expand All @@ -526,28 +534,113 @@ def relu(x):
name="y",
)

def biasadd(data, bias, axis):
    """Create an extern op that calls the CuDNN "add" packed function.

    Parameters
    ----------
    data : tvm.te.Tensor
        Only its ``shape`` is read here; it defines the output shape.
    bias : tvm.te.Tensor
        The single input tensor handed to the packed function.
    axis : int
        Forwarded unchanged as the last packed-call argument.

    Returns
    -------
    tvm.te.Tensor
        The extern output tensor, named "y".
    """
    # NOTE(review): `data` is not passed as an input buffer, only its shape is
    # used. Presumably the packed function expects the output buffer to
    # already hold `data` -- confirm against the C++ side before relying on it.
    return te.extern(
        data.shape,
        [bias],  # inputs
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.add", ins[0], outs[0],
            1, 1, axis  # two scalars (presumably alpha, beta -- TODO confirm), then axis
        ),
        # out_buffers = [data],
        name="y",
    )



def prepare(x, w, pad, stride, dilation, tensor_format, algo, conv_dtype, groups):
    """Resolve the shared arguments of a CuDNN convolution call.

    Normalizes pad/stride/dilation, computes the output shape and selects the
    forward algorithm.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input feature map; must have 4 or 5 dimensions.
    w : tvm.te.Tensor
        Convolution weight.
    pad, stride, dilation : int or list of int
        Passed through ``_prepare_global_func_params`` for normalization.
    tensor_format : int
        Forwarded to ``conv_output_shape`` / ``conv_find_algo``.
    algo : int
        Forward algorithm; ``-1`` requests automatic selection when the batch
        size is static.
    conv_dtype : str or None
        Convolution data type; defaults to ``x.dtype`` when None.
    groups : int
        Number of groups, forwarded to the shape/algorithm helpers.

    Returns
    -------
    tuple
        ``(dims, conv_dtype, pad, stride, dilation, oshape, algo)``
    """
    dims = len(x.shape)
    assert dims in (4, 5)

    conv_dtype = x.dtype if conv_dtype is None else conv_dtype
    pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)

    x_shape = list(x.shape)
    w_shape = list(w.shape)
    static_batch = isinstance(x.shape[0], tvm.tir.expr.IntImm)

    if not static_batch:
        # The dynamic batch size case, pretend this is a single batch
        x_shape[0] = 1

    oshape = conv_output_shape(
        tensor_format,
        pad,
        stride,
        dilation,
        x_shape,
        w_shape,
        x.dtype,
        conv_dtype,
        groups,
    )

    if static_batch:
        if algo == -1:
            # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
            # using INT8 data type, CuDNN will crash down.
            # On the other hand, CuDNN only support IMPLICIT_PRECOMP_GEMM at NHWC format
            if tensor_format == 1 and conv_dtype == "int32":
                algo = 1
            else:
                algo = conv_find_algo(
                    tensor_format,
                    pad,
                    stride,
                    dilation,
                    x_shape,
                    w_shape,
                    oshape,
                    x.dtype,
                    conv_dtype,
                    groups,
                )
    else:
        # Restore the symbolic batch dimension in the output shape.
        oshape[0] = x.shape[0]
        # This picks CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
        # It seems this is the fastest among algorithms that are always applicable
        algo = 1

    return dims, conv_dtype, pad, stride, dilation, oshape, algo



def conv2d_biasadd_relu(x, w, z, b, pad, stride, dilation, conv_mode, tensor_format, algo, conv_dtype, activ_mode, nan_prop_mode, actv_coeff, groups=1):
    """Create an extern op for the fused CuDNN conv2d + bias + activation call.

    Parameters
    ----------
    x, w, z, b : tvm.te.Tensor
        Passed to the packed function as data, weight, ``z`` and bias buffers,
        in that order.
    pad, stride, dilation, tensor_format, algo, conv_dtype, groups
        Normalized by ``prepare`` before being forwarded.
    conv_mode, activ_mode, nan_prop_mode, actv_coeff
        Forwarded to the packed function unchanged.

    Returns
    -------
    tvm.te.Tensor
        The extern output tensor, named "y", with the shape computed by
        ``prepare``.
    """
    # The first element returned by prepare (the rank) is not needed here.
    _, conv_dtype, pad, stride, dilation, oshape, algo = prepare(
        x, w, pad, stride, dilation, tensor_format, algo, conv_dtype, groups
    )

    def _packed_call(ins, outs):
        data_buf = ins[0]
        weight_buf = ins[1]
        z_buf = ins[2]
        bias_buf = ins[3]
        return tvm.tir.call_packed(
            "tvm.contrib.cudnn.conv2d+bias+activation.forward",
            conv_mode,
            tensor_format,
            algo,
            pad[0],
            pad[1],
            stride[0],
            stride[1],
            dilation[0],
            dilation[1],
            conv_dtype,
            data_buf,
            weight_buf,
            z_buf,
            bias_buf,
            outs[0],
            groups,
            # Four fixed scalars; the original labels the first three
            # alphas[0], alphas[1] and "alphas[0] for z" -- exact meaning is
            # defined by the packed function (TODO confirm).
            1,
            0,
            1,
            0,
            activ_mode,
            nan_prop_mode,
            actv_coeff,
        )

    return te.extern(oshape, [x, w, z, b], _packed_call, name="y")
177 changes: 177 additions & 0 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
from .. import ty as _ty
from . import _backend

# for target-specific lowering
from tvm.relay.op import op as _op
#from tvm.relay.analysis import post_order_visit
from tvm import relay
from tvm import topi
from tvm.relay.op.strategy.generic import *
from tvm import te
from tvm.contrib.cudnn import softmax

logger = logging.getLogger("compile_engine")
autotvm_logger = logging.getLogger("autotvm")

Expand Down Expand Up @@ -186,6 +195,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor])
The best op implementation and the corresponding output tensors.
"""

all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
best_plevel_impl = max(all_impls, key=lambda x: x.plevel)

Expand Down Expand Up @@ -263,6 +273,173 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
return best_plevel_impl, outputs[best_plevel_impl]


@tvm._ffi.register_func("relay.backend.target_specific_lowering")
def target_specific_lowering(func, inputMap, target_info=None):
    """Lower a matched Relay function directly to one external-library kernel.

    Eventually this should be expressed as a proper custom implementation;
    for now an ad-hoc ``OpStrategy`` is built for the requested pattern and
    its first implementation is invoked.

    Parameters
    ----------
    func : relay expression
        Walked with ``relay.analysis.post_order_visit`` to collect its Call
        nodes (in post-order).
    inputMap : mapping
        Maps Relay expressions to sequences of input tensors. It is iterated
        with ``.items()`` and, for fused patterns, indexed by call arguments.
    target_info : str
        ``"<target>_<pattern>"``, e.g. ``"cudnn_softmax"`` or
        ``"cudnn_conv2d+biasadd+relu"``.

    Returns
    -------
    LoweredOutput or None
        The outputs and implementation of the first registered
        implementation, or None when nothing was registered for the
        target/pattern pair (including the not-yet-supported "bn" pattern).

    Raises
    ------
    ValueError
        If ``target_info`` is None.
    """
    if target_info is None:
        raise ValueError(
            "target_info must be a '<target>_<pattern>' string, got None"
        )

    # Collect every Call node of the function, in post-order.
    calls = []

    def _collect_call(expr):
        if isinstance(expr, tvm.relay.expr.Call):
            calls.append(expr)

    relay.analysis.post_order_visit(func, _collect_call)

    tokens = target_info.split("_")
    target = tokens[0]
    pattern = tokens[1]

    strategy = _op.OpStrategy()

    def _register(compute, name):
        # Every pattern handled here uses the generic extern schedule.
        strategy.add_implementation(
            compute,
            wrap_topi_schedule(topi.generic.schedule_extern),
            name=name,
        )

    def _single_op_args():
        # Single-op patterns: take attrs/type from the first call and
        # flatten all tensors of inputMap into one input list.
        call = calls[0]
        flat_inputs = []
        for _, tensors in inputMap.items():
            for tensor in tensors:
                flat_inputs.append(tensor)
        return call.attrs, call.checked_type, flat_inputs

    attrs, ret_type, inputs = None, None, None

    if target == "cudnn":
        if pattern == "softmax":
            _register(wrap_custom_compute_softmax(topi.cuda.softmax_cudnn), "softmax.cudnn")
            attrs, ret_type, inputs = _single_op_args()

        elif pattern == "relu":
            _register(wrap_custom_compute_relu(topi.cuda.relu_cudnn), "relu.cudnn")
            attrs, ret_type, inputs = _single_op_args()

        # TODO: not supported yet
        elif pattern == "biasadd":
            _register(wrap_custom_compute_biasadd(topi.cuda.biasadd_cudnn), "biasadd.cudnn")
            attrs, ret_type, inputs = _single_op_args()

        elif pattern == "conv2d":
            _register(
                wrap_custom_compute_conv2d(
                    topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
                ),
                "conv2d.cudnn",
            )
            attrs, ret_type, inputs = _single_op_args()

        elif pattern == "maxpool2d":
            _register(wrap_custom_compute_maxpool2d(topi.cuda.maxpool2d_cudnn), "maxpool2d.cudnn")
            attrs, ret_type, inputs = _single_op_args()

        # TODO: not supported yet -- no implementation is registered, so the
        # function falls through and returns None for this pattern.
        elif pattern == "bn":
            attrs, ret_type, inputs = _single_op_args()

        # fused ops
        # TODO: has correctness issue
        elif pattern == "conv2d+biasadd+relu":
            _register(
                wrap_custom_compute_conv2d_biasadd_relu(
                    topi.cuda.conv2d_biasadd_relu_cudnn, need_data_layout=True, has_groups=True
                ),
                "conv2d_biasadd_relu.cudnn",
            )

            data, kernel, Z, bias = None, None, None, None
            for call in calls:
                call_name = call.op.name
                if "conv2d" in call_name:
                    attrs = call.attrs
                    ret_type = call.checked_type
                    data = inputMap[call.args[0]]
                    kernel = inputMap[call.args[1]]
                elif "bias_add" in call_name:
                    # Use this call's own arguments; reusing the conv2d
                    # arguments here would select the kernel tensor instead
                    # of the bias.
                    bias = inputMap[call.args[1]]
                elif "relu" in call_name:
                    # Z is the same tensor as the convolution data input.
                    # NOTE(review): presumably a placeholder for the fused
                    # kernel's extra operand -- confirm.
                    Z = data

            inputs = [data[0], kernel[0], Z[0], bias[0]]

    elif target == "cublas":
        if pattern == "dense":
            _register(wrap_compute_dense(topi.cuda.dense_cublas), "dense.cublas")
            attrs, ret_type, inputs = _single_op_args()

    # Invoke the first (and only) registered implementation, if any.
    for spec in strategy.specializations:
        for impl in spec.implementations:
            # attribute, inputs, output_type
            outputs = impl.compute(attrs, inputs, ret_type)
            return LoweredOutput(outputs, impl)

    # Nothing was registered for this target/pattern.
    return None


@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
"""Lower the call expression to op implementation and tensor outputs."""
Expand Down
Loading

0 comments on commit 465ff28

Please sign in to comment.