diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index e748dc77861a..504ad8527368 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -470,9 +470,8 @@ def softmax(x, axis=-1):
         name="y",
     )

-#inputs[0], pool_size, strides, padding, "max", ceil_mode, data_layout, True)
-def max_pool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad):
+def maxpool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad):
     """Compute 2D max pooling using cuDNN

     Parameters
@@ -501,13 +500,23 @@ def max_pool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout
         int horizontalStride = args[11];
     """

-    print("Python cudnn.py pool2d!!", file=sys.stderr)
+    dims = 2
+    data_shape = x.shape
+    output_shape = list(data_shape)
+    # outputDim = 1 + (inputDim + 2*padding - windowDim)/poolingStride
+    for i in range(dims):
+        output_shape[i + 2] = int(1 + tvm.tir.div((data_shape[i + 2] + 2 * padding[i] - pool_size[i]), strides[i]))
+
     return te.extern(
-        x.shape,
+        output_shape,
         [x],
         lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cudnn.pooling.forward",
             ins[0],
             outs[0],
-            1.0, 0.0, 0, 0, pool_size[0], pool_size[1], padding[0], padding[1],
+            1, 0,  # alpha, beta
+            3,     # MODE: CUDNN_POOLING_MAX_DETERMINISTIC
+            0,     # CUDNN_NOT_PROPAGATE_NAN
+            pool_size[0], pool_size[1],
+            padding[0], padding[1],
             strides[0], strides[1]
         ),
         name="y",
@@ -515,7 +524,6 @@ def max_pool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout


 def relu(x):
-    print("Python cudnn.py relu!!", file=sys.stderr)
     return te.extern(
         x.shape,
         [x],
@@ -526,28 +534,113 @@ def relu(x):
         name="y",
     )

-def bias_add(data, bias, axis):
-    print("Python cudnn.py bias_add!!", file=sys.stderr)
+def biasadd(data, bias, axis):
     return te.extern(
         data.shape,
-        [bias, data],
+        [bias],  # inputs
         lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cudnn.add",
             ins[0],
             outs[0],
-            1, 1
+            1, 1, axis
         ),
+        # out_buffers=[data],
         name="y",
     )

-def conv_bias_activation_forward(data, ):
-    return te.extern(
-        oshape,
-        [x, w],
+def prepare(x, w, pad, stride, dilation, tensor_format, algo, conv_dtype, groups):
+    dims = len(x.shape)
+    assert dims in (4, 5)
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.cudnn.conv2d+bias+activation.forward",
-        ),
-        name="y",
-    )
+    conv_dtype = x.dtype if conv_dtype is None else conv_dtype
+    pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)
+    x_shape = list(x.shape)
+
+    if isinstance(x.shape[0], tvm.tir.expr.IntImm):
+        oshape = conv_output_shape(
+            tensor_format,
+            pad,
+            stride,
+            dilation,
+            x_shape,
+            list(w.shape),
+            x.dtype,
+            conv_dtype,
+            groups,
+        )
+        if algo == -1:
+            # For now, if we try to call `cudnnFindConvolutionForwardAlgorithm` when
+            # using the INT8 data type, cuDNN will crash.
+            # On the other hand, cuDNN only supports IMPLICIT_PRECOMP_GEMM with the NHWC format
+            if tensor_format == 1 and conv_dtype == "int32":
+                algo = 1
+            else:
+                algo = conv_find_algo(
+                    tensor_format,
+                    pad,
+                    stride,
+                    dilation,
+                    list(x.shape),
+                    list(w.shape),
+                    oshape,
+                    x.dtype,
+                    conv_dtype,
+                    groups,
+                )
+    else:
+        # The dynamic batch size case; pretend this is a single batch
+        x_shape[0] = 1
+        oshape = conv_output_shape(
+            tensor_format,
+            pad,
+            stride,
+            dilation,
+            x_shape,
+            list(w.shape),
+            x.dtype,
+            conv_dtype,
+            groups,
+        )
+        oshape[0] = x.shape[0]
+        # This picks CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM.
+        # It seems to be the fastest among the algorithms that are always applicable.
+        algo = 1
+
+    return dims, conv_dtype, pad, stride, dilation, oshape, algo
+
+
+def conv2d_biasadd_relu(x, w, z, b, pad, stride, dilation, conv_mode, tensor_format, algo, conv_dtype, activ_mode, nan_prop_mode, actv_coeff, groups=1):
+
+    dims, conv_dtype, pad, stride, dilation, oshape, algo = \
+        prepare(x, w, pad, stride, dilation, tensor_format, algo, conv_dtype, groups)
+
+    return te.extern(
+        oshape,
+        [x, w, z, b],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.conv2d+bias+activation.forward",
+            conv_mode,      # mode: CUDNN_CONVOLUTION
+            tensor_format,  # CUDNN_TENSOR_NCHW
+            algo,
+            pad[0], pad[1],
+            stride[0], stride[1],
+            dilation[0], dilation[1],
+            conv_dtype,
+            ins[0],   # x
+            ins[1],   # w
+            ins[2],   # z
+            ins[3],   # bias
+            outs[0],  # y
+            groups,
+            1,  # alphas[0]
+            0,  # alphas[1]
+            1,  # alphas[0] for z
+            0,
+            activ_mode,
+            nan_prop_mode,
+            actv_coeff
+        ),
+        name="y",
+    )
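For reference, the output-shape rule used by `maxpool2d` above can be checked in plain Python; a minimal sketch with illustrative shapes (not taken from the patch):

# outputDim = 1 + (inputDim + 2*padding - windowDim) // poolingStride
data_shape = (1, 64, 56, 56)          # NCHW, illustrative
pool_size, padding, strides = (3, 3), (1, 1), (2, 2)

output_shape = list(data_shape)
for i in range(2):  # the two spatial dims, H and W
    output_shape[i + 2] = 1 + (data_shape[i + 2] + 2 * padding[i] - pool_size[i]) // strides[i]

print(output_shape)  # [1, 64, 28, 28]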
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 68397cc0cef6..7af29c6d6ba7 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -30,6 +30,15 @@
 from .. import ty as _ty
 from . import _backend

+# for target-specific lowering
+from tvm.relay.op import op as _op
+#from tvm.relay.analysis import post_order_visit
+from tvm import relay
+from tvm import topi
+from tvm.relay.op.strategy.generic import *
+from tvm import te
+from tvm.contrib.cudnn import softmax
+
 logger = logging.getLogger("compile_engine")
 autotvm_logger = logging.getLogger("autotvm")

@@ -186,6 +195,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
     ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor])
         The best op implementation and the corresponding output tensors.
     """
+
     all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)

     best_plevel_impl = max(all_impls, key=lambda x: x.plevel)
@@ -263,6 +273,173 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
     return best_plevel_impl, outputs[best_plevel_impl]

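The hook added below assumes `target_info` encodes the backend name and the fused-op pattern joined by a single underscore. A minimal sketch of that convention, with illustrative values:

def parse_target_info(target_info):
    # "cudnn_conv2d+biasadd+relu" -> ("cudnn", "conv2d+biasadd+relu")
    # Assumes the backend name itself contains no underscore.
    tokens = target_info.split("_")
    return tokens[0], tokens[1]

assert parse_target_info("cudnn_softmax") == ("cudnn", "softmax")
assert parse_target_info("cublas_dense") == ("cublas", "dense")
assert parse_target_info("cudnn_conv2d+biasadd+relu") == ("cudnn", "conv2d+biasadd+relu")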
+@tvm._ffi.register_func("relay.backend.target_specific_lowering")
+def target_specific_lowering(func, inputMap, target_info=None):
+
+    import sys
+    #print("\t[Compile_engine.py] Custom lowering?", file=sys.stderr)
+
+    # Eventually, we want to define a custom implementation here.
+    # However, we do not know how to do that yet, so for now we take this hacky path.
+
+    strategy = _op.OpStrategy()
+    # Collect call nodes from the Relay expression via a post-order callback.
+    #relay.analysis.post_order_visit(mod['main'], lambda expr: log_backend_op_perf(b_op_lib, expr, target))
+    #inputs = relay.analysis.free_vars(func.body)
+
+    calls = []
+    def extract_attr(expr, calls):
+        if isinstance(expr, tvm.relay.expr.Call):
+            calls.append(expr)
+    relay.analysis.post_order_visit(func, lambda expr: extract_attr(expr, calls))
+
+    tokens = target_info.split('_')
+    target = tokens[0]
+    pattern = tokens[1]
+
+    def collect_input(inputMap):
+        inputs = []
+        for key, varray in inputMap.items():
+            for val in varray:
+                inputs.append(val)
+        return inputs
+
+    attrs, ret_type = None, None
+    if target == "cudnn":
+        if pattern == "softmax":
+            strategy.add_implementation(
+                wrap_custom_compute_softmax(topi.cuda.softmax_cudnn),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="softmax.cudnn",
+            )
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+        elif pattern == "relu":
+            strategy.add_implementation(
+                wrap_custom_compute_relu(topi.cuda.relu_cudnn),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="relu.cudnn",
+            )
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+        # TODO: not supported yet
+        elif pattern == "biasadd":
+            strategy.add_implementation(
+                wrap_custom_compute_biasadd(topi.cuda.biasadd_cudnn),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="biasadd.cudnn",
+            )
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+        elif pattern == "conv2d":
+            strategy.add_implementation(
+                wrap_custom_compute_conv2d(
+                    topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
+                ),
+                #wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="conv2d.cudnn",
+            )
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+        elif pattern == "maxpool2d":
+            strategy.add_implementation(
+                wrap_custom_compute_maxpool2d(topi.cuda.maxpool2d_cudnn),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="maxpool2d.cudnn",
+            )
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+        # TODO: not supported yet
+        elif pattern == "bn":
+            #strategy.add_implementation(
+            #    wrap_custom_compute_maxpool2d(topi.cuda.maxpool2d_cudnn),
+            #    wrap_topi_schedule(topi.generic.schedule_extern),
+            #    name="bn.cudnn",
+            #)
+
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+        # fused ops
+        # TODO: has a correctness issue
+        elif pattern == "conv2d+biasadd+relu":
+            strategy.add_implementation(
+                wrap_custom_compute_conv2d_biasadd_relu(
+                    topi.cuda.conv2d_biasadd_relu_cudnn, need_data_layout=True, has_groups=True
+                ),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="conv2d_biasadd_relu.cudnn",
+            )
+
+            data, kernel, Z, bias = None, None, None, None
+            attrs, ret_type = None, None
+            for call in calls:
+                call_name = call.op.name
+                if "conv2d" in call_name:
+                    attrs = call.attrs
+                    ret_type = call.checked_type
+                    args = call.args
+                    data = inputMap[args[0]]
+                    kernel = inputMap[args[1]]
+                elif "bias_add" in call_name:
+                    # Use this call's own args; `args` above belongs to the conv2d call.
+                    bias = inputMap[call.args[1]]
+                elif "relu" in call_name:
+                    Z = inputMap[call.args[0]]
+
+            inputs = [data[0], kernel[0], Z[0], bias[0]]
+
+    elif target == "cublas":
+        if pattern == "dense":
+            strategy.add_implementation(
+                wrap_compute_dense(topi.cuda.dense_cublas),
+                wrap_topi_schedule(topi.generic.schedule_extern),
+                name="dense.cublas",
+            )
+            # has a single op
+            attrs = calls[0].attrs
+            ret_type = calls[0].checked_type
+            inputs = collect_input(inputMap)
+
+    # To compute the subgraph we need:
+    # - attrs for each op
+    # - inputs for the subgraph
+    # - the pattern (given)
+
+    # May need a rewrite?
+
+    impl, outputs = None, None
+    for spec in strategy.specializations:
+        for impl in spec.implementations:
+            # attribute, inputs, output_type
+            outputs = impl.compute(attrs, inputs, ret_type)
+            return LoweredOutput(outputs, impl)
+
+    # Should not be reached
+    return None
+
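The call-collection idiom in `target_specific_lowering` can be exercised on its own; a small sketch, assuming a working TVM install (shapes illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 10), dtype="float32")
func = relay.Function([x], relay.nn.relu(relay.nn.softmax(x)))

calls = []
relay.analysis.post_order_visit(
    func, lambda expr: calls.append(expr) if isinstance(expr, relay.expr.Call) else None
)
print([c.op.name for c in calls])  # post order: ['nn.softmax', 'nn.relu']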

 @tvm._ffi.register_func("relay.backend.lower_call")
 def lower_call(call, inputs, target):
     """Lower the call expression to op implementation and tensor outputs."""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 322a3607904f..9691d518bd8c 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -135,6 +135,7 @@ def schedule_adaptive_pool(attrs, outs, target):
     return topi.generic.schedule_adaptive_pool(outs)


+
 # softmax
 def wrap_compute_softmax(topi_compute):
     """Wrap softmax topi compute"""
@@ -146,6 +147,125 @@ def _compute_softmax(attrs, inputs, out_type):
     return _compute_softmax


+# sung: softmax
+def wrap_custom_compute_softmax(topi_compute):
+    """Wrap softmax topi compute"""
+
+    def _compute_softmax(attrs, inputs, out_type):
+        axis = attrs.get_int("axis")
+        return [topi_compute(inputs[0], axis)]
+
+    return _compute_softmax
+
+
+# sung: pooling
+def wrap_custom_compute_maxpool2d(topi_compute):
+    """Wrap pooling topi compute"""
+
+    def _compute_pool(attrs, inputs, out_type):
+        #data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW", count_include_pad=True
+        pool_size = get_const_tuple(attrs.pool_size)
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        data_layout = attrs.get_str("layout")
+        ceil_mode = attrs.ceil_mode
+
+        return [topi_compute(inputs[0], pool_size, strides, padding, "max", ceil_mode, data_layout, True)]
+
+    return _compute_pool
+
+
+# sung: relu
+def wrap_custom_compute_relu(topi_compute):
+    """Wrap relu topi compute"""
+
+    def _compute_relu(attrs, inputs, out_type):
+        return [topi_compute(inputs[0])]
+
+    return _compute_relu
+
+
+# sung: biasadd
+def wrap_custom_compute_biasadd(topi_compute):
+    """Wrap bias add topi compute"""
+
+    def _compute_biasadd(attrs, inputs, out_type):
+        axis = attrs.get_int("axis")
+        return [topi_compute(inputs[0], inputs[1], axis)]
+
+    return _compute_biasadd
+
+
+# sung: conv2d
+def wrap_custom_compute_conv2d(
+    topi_compute,
+    need_data_layout=False,
+    need_out_layout=False,
+    has_groups=False,
+    need_auto_scheduler_layout=False,
+):
+    """Wrap conv2d topi compute"""
+
+    def _compute_conv2d(attrs, inputs, out_type):
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        data_layout = attrs.get_str("data_layout")
+        out_layout = attrs.get_str("out_layout")
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        args = [inputs[0], inputs[1], strides, padding, dilation]
+        if has_groups:
+            args.append(attrs.groups)
+        if need_data_layout:
+            args.append(data_layout)
+        if need_out_layout:
+            args.append(out_layout)
+        args.append(out_dtype)
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
+
+    return _compute_conv2d
+
+
+def wrap_custom_compute_conv2d_biasadd_relu(
+    topi_compute,
+    need_data_layout=False,
+    need_out_layout=False,
+    has_groups=False,
+    need_auto_scheduler_layout=False,
+):
+    """Wrap fused conv2d+biasadd+relu topi compute"""
+    def _compute_conv2d_biasadd_relu(attrs, inputs, out_type):
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        data_layout = attrs.get_str("data_layout")
+        out_layout = attrs.get_str("out_layout")
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+
+        args = [inputs[0], inputs[1], inputs[2], inputs[3], \
+                strides, padding, dilation]
+        if has_groups:
+            args.append(attrs.groups)
+        if need_data_layout:
+            args.append(data_layout)
+        if need_out_layout:
+            args.append(out_layout)
+        args.append(out_dtype)
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
+
+    return _compute_conv2d_biasadd_relu
+
+
 @override_native_generic_func("softmax_strategy")
 def softmax_strategy(attrs, inputs, out_type, target):
     """softmax generic strategy"""
@@ -201,6 +321,7 @@ def schedule_bitpack(attrs, outs, target):
 )

 # conv2d
+
 def wrap_compute_conv2d(
     topi_compute,
     need_data_layout=False,
@@ -233,6 +354,7 @@ def _compute_conv2d(attrs, inputs, out_type):

     return _compute_conv2d

+
 @override_native_generic_func("conv2d_strategy")
 def conv2d_strategy(attrs, inputs, out_type, target):
     """conv2d generic strategy"""
diff --git a/python/tvm/relay/testing/init.py b/python/tvm/relay/testing/init.py
index f275712c77d1..d149977317d4 100644
--- a/python/tvm/relay/testing/init.py
+++ b/python/tvm/relay/testing/init.py
@@ -169,11 +169,14 @@ def create_workload(net, initializer=None, seed=0):
     mod = relay.transform.InferType()(mod)
     shape_dict = {v.name_hint: v.checked_type for v in mod["main"].params}
     np.random.seed(seed)
-    initializer = initializer if initializer else Xavier()
+
+    initializer = initializer if initializer is not None else Xavier()
+
     params = {}
     for k, v in shape_dict.items():
         if k == "data":
             continue
+        #init_value = np.random.uniform(-1, 1, size=v.concrete_shape).astype(v.dtype)
         init_value = np.zeros(v.concrete_shape).astype(v.dtype)
         initializer(k, init_value)
         params[k] = tvm.nd.array(init_value, device=tvm.cpu(0))
diff --git a/python/tvm/relay/transform/backend_operator/backend_op.py b/python/tvm/relay/transform/backend_operator/backend_op.py
index 00d23a88dc2e..2bdd062a8c7d 100644
--- a/python/tvm/relay/transform/backend_operator/backend_op.py
+++ b/python/tvm/relay/transform/backend_operator/backend_op.py
@@ -15,14 +15,14 @@
 from .pattern import Pattern
 from .utils import get_diamond
-from .utils import is_call_node, is_tuplegetitem_node, is_var_node, no_constraints_func, is_constant_node
+from .utils import *
 from .op_config import Config, MeasuredConfigs
 from .target import Target, get_target_cost_func
 from .op_type import OpType, optype_to_pattern, relayop_to_varnames

 # It gives the path of backend_op.py no matter where you import this file
 # cur_dir_path = Path(__file__).parent.absolute()
-# RES_LOG = f"{cur_dir_path}/logs/runtime_results.log"
+# RES_LOG = f"{cur_dir_path}/../logs/runtime_results.log"
 # redirect stdout to this log so it is not intertwined with TVM backend log output
 # sys.stdout = open(RES_LOG, 'w')

@@ -62,9 +62,13 @@ def get_cost(self, expr):
         config = Config(self._name, self._op_type.name(), expr)
         # print(config)

+        # For tuples, we do not need to measure anything
+        if is_tuple_node(expr) or is_tuplegetitem_node(expr):
+            return 0, 0
+
         # if constraints are not satisfied, return infinite cost
         if not self._constraint_func(config):
-            return float('inf')
+            return float('inf'), 0

         cost_info = self._measured_configs.get_cost(config)
         if cost_info != None:
@@ -104,12 +108,10 @@ def helper(expr, depth):
         if is_call_node(expr):
             # note that only a call node has an "op" attribute corresponding to a single backend operator
             op, args, attrs, type_args, span = expr.op, expr.args, expr.attrs, expr.type_args, expr.span
-
             new_args = []
             # at depth 1, turn call expr arguments into free variables with the same attributes and data shapes!
             if depth == 1:
                 var_names = relayop_to_varnames[op.name]
-
                 # # of arguments should match # of type arguments
                 # Fix: This happens in BERT. We need to deal with it
                 # It means that type inference hasn't been executed (type_args are not filled)
                 if len(expr.args) != len(expr.type_args):
                     raise Exception("The type inference pass hasn't been executed.")
                 else:
+                    # print(expr.op, var_names)
                     for i in range(len(expr.args)):
                         type_arg = expr.type_args[i]
                         var_name = var_names[i]
-
+
+                        # Tuples should be treated separately
+                        if isinstance(type_arg, tvm.ir.type.TupleType):
+                            input_data = expr.args[i]
+                            #print(type_arg.fields)
+                            new_args.append(relay.Tuple([relay.var(var_name, d) for i, d in enumerate(type_arg.fields)]))
                         # Bias should be constant
-                        if var_name == 'bias':
+                        elif var_name == 'bias':
                             input_data = expr.args[i].data
                             new_args.append(relay.Constant(input_data))
                         else:
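The TupleType branch added to `helper` above builds a fresh tuple of free variables from the argument's type fields; a minimal standalone sketch (toy tensor types, illustrative names):

import tvm
from tvm import relay

fields = [relay.TensorType((1, 32, 8, 8), "float32"),
          relay.TensorType((1, 32, 8, 8), "float32")]
tuple_ty = tvm.ir.TupleType(fields)

# Mirror the patch: one free var per tuple field, wrapped in a relay.Tuple.
new_arg = relay.Tuple([relay.var("data", d) for i, d in enumerate(tuple_ty.fields)])
print(new_arg)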
# self._add_backendop("cudnn_maxpool2d", Target.CUDNN, OpType.MAX_POOL2D, 1) - # conv_bias_add_relu --> ResNet doesn't have this pattern, so it wouldn't be measured - # self._add_backendop("cudnn_conv2d+biasadd+relu", Target.CUDNN, OpType.CONV2D_BIAS_ADD_RELU, 3) + self._add_backendop("cudnn_conv2d+biasadd+relu", Target.CUDNN, OpType.CONV2D_BIAS_ADD_RELU, 3) # TENSORRT - self._add_backendop("tensorrt_conv2d", Target.TENSORRT, OpType.CONV2D, 1) - self._add_backendop("tensorrt_relu", Target.TENSORRT, OpType.RELU, 1) - # add_all_backend_ops_to_lib(self, Target.TENSORRT) + add_all_backend_ops_to_lib(self, Target.TENSORRT) # CUBLAS # TODO: Add patterns. matmul, batch matmul diff --git a/python/tvm/relay/transform/backend_operator/logs/.gitignore b/python/tvm/relay/transform/backend_operator/logs/.gitignore new file mode 100644 index 000000000000..1e6e2782cd66 --- /dev/null +++ b/python/tvm/relay/transform/backend_operator/logs/.gitignore @@ -0,0 +1 @@ +olds \ No newline at end of file diff --git a/python/tvm/relay/transform/backend_operator/op_config.py b/python/tvm/relay/transform/backend_operator/op_config.py index b69a940944d6..d617770fc4ea 100644 --- a/python/tvm/relay/transform/backend_operator/op_config.py +++ b/python/tvm/relay/transform/backend_operator/op_config.py @@ -4,7 +4,7 @@ from os import path cur_dir_path = Path(__file__).parent.absolute() -COST_LOG = f"{cur_dir_path}/logs/operator_cost.log" +COST_LOG = f"{cur_dir_path}/../logs/operator_cost.log" # configuration includes operator name, operator type (backend operators from different targets might have the same type), # data shape of all free variables, and node attributes @@ -15,7 +15,7 @@ def __init__(self, op_name, op_type, expr, data_shape=None, attrs=None): self._op_type = op_type if expr != None: - self._data_shape = get_data_shape(expr) + self._data_shape = tuple(get_data_shape(expr)) self._attrs = extract_attrs(expr) else: # Debugging purpose diff --git a/python/tvm/relay/transform/backend_operator/op_type.py b/python/tvm/relay/transform/backend_operator/op_type.py index d97da94ad1c8..a8d4b8b248f6 100644 --- a/python/tvm/relay/transform/backend_operator/op_type.py +++ b/python/tvm/relay/transform/backend_operator/op_type.py @@ -8,24 +8,43 @@ # because they are only used to extract result of Relay's batch_norm operator class OpType(Enum): # ID, name, depth + # RESNE(X)T ADD = (0, 'add', 1) CONV2D = (1, 'conv2d', 1) RELU = (2, 'relu', 1,) - CONV2D_RELU = (15, 'conv2d+relu', 2) - - # Other ops than ResNet50 - DIAMOND = (13, 'diamond', 6) # Not sure yet if it works well for DP - BN = (3, 'bn', 1) - SOFTMAX = (4, 'softmax', 1) - BIAS_ADD = (5, 'biasadd', 1) - DENSE = (6, 'dense', 1) - BATCH_FLATTEN = (7, 'batchflatten', 1) - GLOBAL_AVG_POOL2D = (8, 'globalavgpool2d', 1) - MAX_POOL2D = (9, 'maxpool2d', 1) - CONV2D_BN = (10, 'conv2d+bn', 2) - BN_RELU = (11, 'bn+relu', 2) - CONV2D_BN_RELU = (12, 'conv2d+bn+relu', 3) - CONV2D_BIAS_ADD_RELU = (14, 'conv2d+biasadd+relu', 3) + CONV2D_RELU = (3, 'conv2d+relu', 2) + + # BERT + DENSE = (4, 'dense', 1) + RESHAPE = (5, 'reshape', 1) + TRANSPOSE = (6, 'transpose', 1) + BATCH_MATMUL = (7, 'batch_matmul', 1) + + # NASRNN + TANH = (8, 'tanh', 1) + SIGMOID = (9, 'sigmoid', 1) + MULTIPLY = (10, 'multiply', 1) + TUPLE_GET_ITEM_0 = (11, 'tuple_get_item_0', 1) + TUPLE_GET_ITEM_1 = (12, 'tuple_get_item_1', 1) + TUPLE_TWO_IDX = (13, 'tuple_two_idx', 1) + + # NASNET-A + CONCAT = (14, 'concat', 1) + BIAS_ADD = (15, 'biasadd', 1) + AVG_POOL2D = (16, 'avgpool2d', 1) + MAX_POOL2D = (17, 'maxpool2d', 1) 
+ TUPLE_FOUR_IDX = (27, 'tuple_four_idx', 1) + + # Others + DIAMOND = (18, 'diamond', 6) # Not sure yet if it works well for DP + BN = (19, 'bn', 1) + SOFTMAX = (20, 'softmax', 1) + BATCH_FLATTEN = (21, 'batchflatten', 1) + GLOBAL_AVG_POOL2D = (22, 'globalavgpool2d', 1) + CONV2D_BN = (23, 'conv2d+bn', 2) + BN_RELU = (24, 'bn+relu', 2) + CONV2D_BN_RELU = (25, 'conv2d+bn+relu', 3) + CONV2D_BIAS_ADD_RELU = (26, 'conv2d+biasadd+relu', 3) def identifier(self): return self.value[0] @@ -38,20 +57,41 @@ def depth(self): # maps op type to pattern representing it optype_to_pattern = { + # RESNE(X)T OpType.ADD : Pattern(is_op('add')(wildcard(), wildcard())), OpType.CONV2D : Pattern(is_op("nn.conv2d")(wildcard(), wildcard())), OpType.RELU : Pattern(is_op("nn.relu")(wildcard())), OpType.CONV2D_RELU : Pattern(is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))), - # Other ops than ResNet50 + # BERT + OpType.DENSE : Pattern(is_op("nn.dense")(wildcard(), wildcard())), + OpType.RESHAPE : Pattern(is_op("reshape")(wildcard())), + OpType.TRANSPOSE : Pattern(is_op("transpose")(wildcard())), + OpType.BATCH_MATMUL : Pattern(is_op("nn.batch_matmul")(wildcard(),wildcard())), + + # NASRNN + OpType.TANH : Pattern(is_op("tanh")(wildcard())), + OpType.SIGMOID : Pattern(is_op("sigmoid")(wildcard())), + OpType.MULTIPLY : Pattern(is_op("multiply")(wildcard(), wildcard())), + OpType.TUPLE_GET_ITEM_0 : Pattern(is_tuple_get_item(wildcard(), 0)), + OpType.TUPLE_GET_ITEM_1 : Pattern(is_tuple_get_item(wildcard(), 1)), + OpType.TUPLE_TWO_IDX : Pattern(is_tuple([wildcard(), wildcard()])), + + # NASNET-A + OpType.CONCAT : Pattern(is_op("concatenate")(wildcard())), + OpType.BIAS_ADD : Pattern(is_op("nn.bias_add")(wildcard(), wildcard())), + OpType.AVG_POOL2D : Pattern(is_op("nn.avg_pool2d")(wildcard())), + OpType.MAX_POOL2D : Pattern(is_op("nn.max_pool2d")(wildcard())), + OpType.TUPLE_FOUR_IDX : Pattern(is_tuple([wildcard(), wildcard(), wildcard(), wildcard(), wildcard()])), + + # Others OpType.DIAMOND : get_diamond(), OpType.BN : Pattern(is_tuple_get_item(is_op("nn.batch_norm")(wildcard(), wildcard(), wildcard(), wildcard(), wildcard()), 0)), OpType.SOFTMAX : Pattern(is_op("nn.softmax")(wildcard())), - OpType.BIAS_ADD : Pattern(is_op("nn.bias_add")(wildcard(), wildcard())), - OpType.DENSE : Pattern(is_op("nn.dense")(wildcard(), wildcard())), OpType.BATCH_FLATTEN : Pattern(is_op("nn.batch_flatten")(wildcard())), OpType.GLOBAL_AVG_POOL2D : Pattern(is_op("nn.global_avg_pool2d")(wildcard())), - OpType.MAX_POOL2D : Pattern(is_op("nn.max_pool2d")(wildcard())), + + # Other Fused Ops OpType.CONV2D_BN : Pattern(is_tuple_get_item(is_op("nn.batch_norm")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard(), wildcard(), wildcard(), wildcard()), 0)), OpType.BN_RELU : Pattern(is_op("nn.relu")(is_tuple_get_item(is_op("nn.batch_norm")(wildcard(), wildcard(), wildcard(), wildcard(), wildcard()), 0))), OpType.CONV2D_BN_RELU : Pattern(is_op("nn.relu")(is_tuple_get_item(is_op("nn.batch_norm")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard(), wildcard(), wildcard(), wildcard()), 0))), @@ -60,17 +100,36 @@ def depth(self): # maps relay operator type to names of input vars. 
 relayop_to_varnames = {
+    # RESNE(X)T
     "add" : ["data", "data"],
     "nn.conv2d" : ["data", "weight"],
     "nn.relu": ["data"],

-    # Other ops than ResNet50
+    # BERT
+    "nn.dense" : ["data", "weight"],
+    "reshape": ["data"],
+    "transpose": ["data"],
+    "nn.batch_matmul" : ["data", "data"],
+    #"nn.batch_matmul" : ["x", "y"],
+
+    # NASRNN
+    "tanh": ["data"],
+    "multiply": ["data", "data"],
+    # "multiply": ["lhs", "rhs"],
+    "sigmoid": ["data"],
+    # FIXME(@Soo): How should we deal with TUPLE and TUPLE_GET_ITEM?
+
+    # NASNET-A
+    "concatenate": ["data"],
+    "nn.bias_add" : ["data", "bias"],
+    "nn.avg_pool2d" : ["data"],
+    "nn.max_pool2d" : ["data"],
+    "tuple" : ["data", "data", "data", "data", "data"],
+
+    # Others
     "nn.batch_norm" : ["data", "bn_data_gamma", "bn_data_beta", "bn_data_moving_mean", "bn_data_moving_var"],
     "nn.softmax" : ["data"],
-    "nn.bias_add" : ["data", "bias"],
-    "nn.dense" : ["data", "weight"],
     "nn.batch_flatten" : ["data"],
     "nn.global_avg_pool2d" : ["data"],
-    "nn.max_pool2d" : ["data"],
 }
diff --git a/python/tvm/relay/transform/backend_operator/plot_ops.py b/python/tvm/relay/transform/backend_operator/plot_ops.py
new file mode 100644
index 000000000000..d72840adf848
--- /dev/null
+++ b/python/tvm/relay/transform/backend_operator/plot_ops.py
@@ -0,0 +1,130 @@
from .op_config import MeasuredConfigs
import os
import matplotlib.pyplot as plt
import pandas as pd

import pickle
from .op_type import OpType
from .op_config import Config
# from target import Target

fw, fh = 15, 4

def gen_op_key(key):
    return f"{key._op_type}, {key._data_shape}, {key._attrs}"

def set_plt_font_size():
    SMALL_SIZE = 14
    MEDIUM_SIZE = 16
    BIGGER_SIZE = 22

    plt.style.use('seaborn-paper')
    plt.rc('font', size=BIGGER_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=BIGGER_SIZE)     # fontsize of the x and y labels
    plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)   # fontsize of the figure title
    plt.rc('axes', linewidth=2)
    plt.rc('lines', linewidth=3)

def get_processed_dp(measured_configs, target_batch_size):
    op2target_perf = {}
    # Ignore the std of perf
    for config, (perf, _) in measured_configs.measured_configs.items():
        op_key = gen_op_key(config)
        # if config._data_shape[0] != target_batch_size:
        #     continue
        target_name = config._op_name.split("_")[0]
        if op_key in op2target_perf:
            # Check membership in the per-op dict, not the outer map.
            if target_name not in op2target_perf[op_key]:
                op2target_perf[op_key][target_name] = perf
        else:
            op2target_perf[op_key] = {target_name: perf}

    # Set new op name with ID
    op2id = {}
    new_op2target_perf = {}
    for key, val in op2target_perf.items():
        op_name = key.split(",")[0]
        if op_name not in op2id:
            op2id[op_name] = 1
        else:
            op2id[op_name] += 1
        op_name = f"{op_name}_{op2id[op_name]}"

        # Check why auto-tuned TVM_GPU operators fall behind a lot
        # if op_name in ['conv2d_4', 'conv2d_7', 'conv2d_10']:
        #     print(key)
        new_op2target_perf[op_name] = val

    return pd.DataFrame(new_op2target_perf).T

def filter_df_with_regex(df, regex, cols_to_exclude):
    filtered_df = df.filter(regex=regex, axis=0)
    filtered_df = filtered_df[filtered_df.columns.difference(cols_to_exclude)]

    return filtered_df

def draw_plot(df, fig_name):
    df.plot.bar(figsize=(fw, fh))

    x_label_invisible = False

    # Save figures
    plt.xlabel('Operators')
    plt.ylabel('Inference Time')

    this_code_path = os.path.dirname(os.path.abspath(__file__))
    fig_name = f'{this_code_path}/../results/plots/{fig_name}'
    if x_label_invisible:
        ax1 = plt.axes()
        x_axis = ax1.axes.get_xaxis()
        x_axis.set_visible(False)

    plt.savefig(fig_name, bbox_inches='tight')
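For reference, `filter_df_with_regex` above keeps the rows whose index matches a regex and drops the excluded columns; a quick sketch with a toy DataFrame (values illustrative):

import pandas as pd

df = pd.DataFrame(
    {"cudnn": [1.0, 2.0, 3.0], "tvmgpu": [0.5, 1.5, 2.5]},
    index=["conv2d_1", "conv2d_2", "dense_1"],
)

filtered = df.filter(regex=r"conv2d_\d{1,2}", axis=0)        # keep conv2d rows
filtered = filtered[filtered.columns.difference(["cudnn"])]  # drop excluded column
print(filtered)  # conv2d_1 and conv2d_2 rows, tvmgpu column only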


if __name__ == "__main__":
    target_batch_size = 1
    set_plt_font_size()

    measured_configs = MeasuredConfigs()
    measured_configs.load_from_log()
    # measured_configs.measured_configs

    df = get_processed_dp(measured_configs, target_batch_size)

    # df['tensorrt'] = df['tvmgpu']/df['tensorrt']
    # df['cudnn'] = df['tvmgpu']/df['cudnn']
    # df['cublas'] = df['tvmgpu']/df['cublas']
    # # df['tvmcpu'] /= df['tvmgpu']
    # df['tvmgpu'] /= df['tvmgpu']

    # Conv GPU plots
    conv_df = filter_df_with_regex(df=df, regex=r'conv2d_\d{1,2}',
                                   cols_to_exclude=['tvmcpu', 'cublas', 'cudnn', 'tensorrt'])  # 'tvmgpu-no-tuning'
    draw_plot(df=conv_df, fig_name=f'conv_rtx_bn{target_batch_size}.png')

    # Matmul (Dense) GPU plot
    # batch_matmul_\d{1,2}
    # print(df[df.index.str.match('conv*') == False])
    dense_df = filter_df_with_regex(df=df, regex=r'(?:dense_\d{1,2}|batch_matmul_\d{1,2})',
                                    cols_to_exclude=['tvmcpu', 'cublas', 'cudnn'])
    # print(dense_df)
    draw_plot(df=dense_df, fig_name=f'matmul_rtx_bn{target_batch_size}.png')

    # # Fused ops
    # fused_df = df.filter(like='+', axis=0)
    # fused_df = fused_df[fused_df.columns.difference(['cudnn', 'cublas', 'tvmcpu', 'tvmgpu-no-tuning'])]
    # draw_plot(df=fused_df, fig_name=f'fused_rtx_bn{target_batch_size}.png')
    #
    # # Other ops
    # # forbidden_str = ['+','conv2d','dense']
    # # re_str = '|'.join(forbidden_str)
    # drop_str = [f"conv2d_{i}" for i in range(1, 21)] + [val for val in df.index.values if "+" in val] + ['dense_1']
    # other_df = df.drop(index=drop_str)
    # other_df = other_df[other_df.columns.difference(['cublas', 'tvmcpu', 'tvmgpu-no-tuning'])]
    #
    # draw_plot(df=other_df, fig_name=f'others_rtx_bn{target_batch_size}.png')
diff --git a/python/tvm/relay/transform/backend_operator/target.py b/python/tvm/relay/transform/backend_operator/target.py
index 157738555ccf..0c255b0f0a8a 100644
--- a/python/tvm/relay/transform/backend_operator/target.py
+++ b/python/tvm/relay/transform/backend_operator/target.py
@@ -9,9 +9,10 @@
 import os
 from pathlib import Path

-from .utils import is_call_node, is_tuplegetitem_node, is_var_node, no_constraints_func, get_data_shape
+from .utils import *

-from tvm.contrib import graph_executor
+from tvm.contrib import graph_executor as runtime
+# from tvm.contrib import graph_executor

 # only collect results whose standard deviation is below this
 MAX_STANDARD_DEVIATION = 5E-04
@@ -20,10 +21,11 @@
 cur_dir_path = Path(__file__).parent.absolute()

-AUTOTVM_LOG = f"{cur_dir_path}/../autotune/autotvm_ops.log"
+AUTOTVM_LOG = f"{cur_dir_path}/../logs/autotvm_ops.log"
 # Temporary autoscheduler log file
 # FIXME(@Soo): Accumulate autoscheduler logs to the same file
-AUTOSCH_LOG = f"{cur_dir_path}/../autotune/autosch_ops.json"
+# AUTOSCH_LOG = "/home/byungsoj/backend-aware-graph-opt/package/autotune/tmp/autosch_ops.json.resnet50.tmp"
+AUTOSCH_LOG = f"{cur_dir_path}/../logs/autosch_ops.json"

 def measure(ftimer, *args):
     # Warm-up Phase: Run without measurement
@@ -92,18 +94,18 @@ def measure_cost(name, expr, target):
     # AutoScheduler codes
     target_str = target.__str__()
-    ctx = tvm.context(target_str, 0)

     with auto_scheduler.ApplyHistoryBest(AUTOSCH_LOG):
         with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
             lib = relay.build(net, target_str, params=params)

-    module = runtime.GraphModule(lib["default"](ctx))
+    dev = tvm.device(target_str, 0)
+    module = runtime.GraphModule(lib["default"](dev))

     # Setup execution
     data_shape = get_data_shape(expr)
     data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
     module.set_input("data", data)
-    ftimer = module.module.time_evaluator("run", ctx, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)
+    ftimer = module.module.time_evaluator("run", dev, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)

     return measure(ftimer)

@@ -126,15 +128,17 @@ def measure_cost(name, expr, target):
     # Compile kernels with history best records
     with autotvm.apply_history_best(AUTOTVM_LOG):
         target_str = target.__str__()
-        ctx = tvm.context(target_str, 0)
-        lib = relay.build_module.build(net, target_str, params=params)
-        module = runtime.GraphModule(lib["default"](ctx))
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build_module.build(net, target=target_str, params=params)
+
+        dev = tvm.device(str(target), 0)
+        module = runtime.GraphModule(lib["default"](dev))

         # Setup execution
         data_shape = get_data_shape(expr)
         data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
         module.set_input("data", data)
-        ftimer = module.module.time_evaluator("run", ctx, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)
+        ftimer = module.module.time_evaluator("run", dev, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)

         return measure(ftimer)

@@ -153,18 +157,20 @@ def measure_cost(name, expr, target):
     # Build the subgraph
     # FIXME(@Soo): We should redesign the Target class to deal with the new TVM build interface
-    target_backend = tvm.target.cuda()
-    target_dev = tvm.gpu()
+    target_str = target.__str__()
     opt_level = 3
     with tvm.transform.PassContext(opt_level=opt_level):
-        lib = relay.build(net, target_backend, params=params)
+        lib = relay.build(net, target_str, params=params)
+
+    dev = tvm.device(str(target), 0)
+    module = runtime.GraphModule(lib["default"](dev))

     # Setup execution
     data_shape = get_data_shape(expr)
-    data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
-    module = graph_executor.GraphModule(lib["default"](target_dev))
+    #data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
+    data = get_data(expr)
     module.set_input("data", data)
-    ftimer = module.module.time_evaluator("run", target_dev, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)
+    ftimer = module.module.time_evaluator("run", dev, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)

     return measure(ftimer)

@@ -181,6 +187,20 @@ def measure_cost(name, expr, target):
 #
 #     return measure(ftimer)

+def get_conv_attr(expr):
+    assert is_call_node(expr)
+    # note that only a call node has an "op" attribute corresponding to a single backend operator
+    op, args, attrs, type_args, span = expr.op, expr.args, expr.attrs, expr.type_args, expr.span
+
+    # extract conv attributes
+    strides, padding, out_channels, dilation = \
+        list(attrs.strides), list(attrs.padding), int(attrs.channels), list(attrs.dilation)
+
+    kernel_size = args[1].type_annotation.shape
+    dtype = args[0].type_annotation.dtype
+
+    return strides, padding, out_channels, dilation, kernel_size, dtype
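`get_conv_attr` above unpacks a conv2d call whose arguments carry type annotations; a minimal sketch of the fields it reads, with illustrative shapes:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="float32")
conv = relay.nn.conv2d(data, weight, strides=(2, 2), padding=(3, 3),
                       channels=64, kernel_size=(7, 7))

# The same attributes get_conv_attr extracts:
strides = list(conv.attrs.strides)                 # [2, 2]
out_channels = int(conv.attrs.channels)            # 64
kernel_size = conv.args[1].type_annotation.shape   # [64, 3, 7, 7]
dtype = conv.args[0].type_annotation.dtype         # "float32"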

 class CuDNNCostFunc(TargetCostFunc):
     def __init__(self):
@@ -228,16 +248,14 @@ def measure_cost(name, expr, target):
     in_channels = data_shape[1]

     if "conv2d" in op_name:
-        assert(is_call_node(expr))
-        # note that only call node has "op" attribute corresponding to a single backend operator
-        op, args, attrs, type_args, span = expr.op, expr.args, expr.attrs, expr.type_args, expr.span
-
-        # extract conv attributes
-        strides, padding, out_channels, dilation = \
-            list(attrs.strides), list(attrs.padding), int(attrs.channels), list(attrs.dilation)
-        kernel_size = args[1].type_annotation.shape
-        dtype = args[0].type_annotation.dtype
+        if op_name == "conv2d+biasadd+relu":
+            strides, padding, out_channels, dilation, kernel_size, dtype = get_conv_attr(expr.args[0].args[0])
+        elif op_name == 'conv2d':
+            strides, padding, out_channels, dilation, kernel_size, dtype = get_conv_attr(expr)
+        else:
+            raise Exception(f"{op_name} is not supported for CUDNN")

         assert(dtype == "float32")

@@ -311,40 +329,46 @@ def measure_cost(name, expr, target):
         perf = measure(ftimer, data, weight, output)

-    elif op_name == "conv2d+bias+relu":
+    elif op_name == "conv2d+biasadd+relu":
+        # Warning: we assume that args[1] corresponds to the bias
+        bias_tensor = expr.args[0].args[1].data
+
         te_data = te.placeholder(data_shape, name="data", dtype=dtype)
         te_kernel = te.placeholder(kernel_size, name="kernel", dtype=dtype)
         te_z = te.placeholder(output_shape, name="Z", dtype=dtype)
-        te_bias = te.placeholder(params["bias"].shape, name="bias", dtype=dtype)
+
+        # Note that the bias is a relay Constant, so it does not appear in params
+        # te_bias = te.placeholder(params["bias"].shape, name="bias", dtype=dtype)
+        te_bias = te.placeholder(bias_tensor.shape, name="bias", dtype=dtype)

         cuDNN_OP = te.extern(
             output_shape,
             [te_data, te_kernel, te_z, te_bias],
             lambda ins, outs: tvm.tir.call_packed(
-                "tvm.contrib.cudnn.pooling.forward",
-                conv_mode,    # mode: CUDNN_CONVOLUTION
-                data_layout,  # CUDNN_TENSOR_NCHW
-                conv_algo,    # ALGO
-                padding[0], padding[1],
-                strides[0], strides[1],
-                dilation[0], dilation[1],
-                dtype,
-                ins[0],   # x
-                ins[1],   # w
-                ins[2],   # z
-                ins[3],   # bias
-                outs[0],  # y
-                groups,
-                1,  # alphas[0]
-                0,  # alphas[1]
-                1,  # alphas[0] for z
-                0,
-                activation_mode,
-                nanProp_mode,
-                actvCoeff
-            ),
-            name="y",
-        )
+                "tvm.contrib.cudnn.conv2d+bias+activation.forward",
+                conv_mode,    # mode: CUDNN_CONVOLUTION
+                data_layout,  # CUDNN_TENSOR_NCHW
+                conv_algo,
+                padding[0], padding[1],
+                strides[0], strides[1],
+                dilation[0], dilation[1],
+                dtype,
+                ins[0],   # x
+                ins[1],   # w
+                ins[2],   # z
+                ins[3],   # bias
+                outs[0],  # y
+                groups,
+                1,  # alphas[0]
+                0,  # alphas[1]
+                1,  # alphas[0] for z
+                0,
+                activation_mode,
+                nanProp_mode,
+                actvCoeff
+            ),
+            name="y",
+        )

         s = te.create_schedule(cuDNN_OP.op)
         func = tvm.build(s, [te_data, te_kernel, te_z, te_bias, cuDNN_OP], target_str, target_host="llvm")
@@ -353,7 +377,8 @@
         data = tvm.nd.array(data_in, ctx)
         weight = tvm.nd.array(params["weight"], ctx)
         ze = tvm.nd.array(np.zeros(output_shape, dtype=dtype), ctx)
-        bias = tvm.nd.array(params["bias"], ctx)
+        # bias = tvm.nd.array(params["bias"], ctx)
+        bias = tvm.nd.array(bias_tensor.asnumpy(), ctx)
         output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), ctx)

         ftimer = func.time_evaluator(func.entry_name, ctx, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS)
@@ -389,9 +414,13 @@ def measure_cost(name, expr, target):

     elif op_name == "biasadd":
         axis = expr.attrs.axis
+        # Warning: we assume that args[1] corresponds to the bias
+        bias_tensor = expr.args[1].data

+        # Note that the bias is a relay Constant, so it does not appear in params
         te_data = te.placeholder(data_shape, name="data",
dtype=dtype) - te_bias = te.placeholder(params["bias"].shape, name="bias", dtype=dtype) + # te_bias = te.placeholder(params["bias"].shape, name="bias", dtype=dtype) + te_bias = te.placeholder(bias_tensor.shape, name="bias", dtype=dtype) output_shape = data_shape cuDNN_OP = te.extern( @@ -411,7 +440,8 @@ def measure_cost(name, expr, target): data_in = np.random.uniform(-1, 1, size=data_shape).astype(dtype) data = tvm.nd.array(data_in, ctx) - bias = tvm.nd.array(params["bias"], ctx) + # bias = tvm.nd.array(params["bias"], ctx) + bias = tvm.nd.array(bias_tensor.asnumpy(), ctx) output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), ctx) ftimer = func.time_evaluator(func.entry_name, ctx, number=NUM_MEASUREMENTS_PER_REPEAT, repeat=NUM_REPEATS) diff --git a/python/tvm/relay/transform/backend_operator/test_ops.py b/python/tvm/relay/transform/backend_operator/test_ops.py new file mode 100644 index 000000000000..91d6b7eec666 --- /dev/null +++ b/python/tvm/relay/transform/backend_operator/test_ops.py @@ -0,0 +1,877 @@ +import tvm +from tvm import te +import numpy as np +import tvm.contrib.graph_executor as runtime +#from tvm.contrib import graph_runtime as runtime +from tvm import relay +from tvm.relay import testing +import pytest as pyt + +from .target import Target +from .utils import is_function_node +from ..workloads.onnx_workloads import get_network_from_onnx + + +gt_target = "cuda" + +def genKey(config): + key = "" + for e in config: + if isinstance(e, str): + key += e + else: + key += str(e) + key += ", " + return key + +def get_gt_net(config): + op = config["op"] + target = gt_target + batch_size = config["batch_size"] + data_shape = config["data_shape"] + out_channels = config["out_channels"] + kernel_size = config["kernel_size"] + strides = config["strides"] + padding = config["padding"] + dilation = config["dilation"] + groups = config["groups"] + data_layout = config["data_layout"] + kernel_layout = config["kernel_layout"] + out_layout = config["out_layout"] + out_dtype = config["out_dtype"] + pool_size = config["pool_size"] + axis = config["axis"] + + + key = genKey(config) + + # Define input tensor shapes and variables + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight") + bn_gamma = relay.var("bn_gamma") + bn_beta = relay.var("bn_beta") + bn_mmean = relay.var("bn_mean") + bn_mvar = relay.var("bn_var") + bias = relay.var("bias") + + # Process given operators + if op == "conv2d": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + elif op == "bn": + simple_net = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar, axis)[0] + elif op == "relu": + simple_net = relay.nn.relu(data) + elif op == "conv2d+bn": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + elif op == "bn+relu": + simple_net = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + simple_net = relay.nn.relu(simple_net) + + elif op == 
"biasadd": + simple_net = relay.nn.bias_add( + data = data, + bias = bias, + axis = axis + ) + elif op == "conv2d+bn+relu": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + simple_net = relay.nn.relu(simple_net) + elif op == "conv2d+bias+relu": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, # conv kernel + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + simple_net = relay.nn.bias_add(simple_net, bias) + simple_net = relay.nn.relu(simple_net) + + elif op == "softmax": + simple_net = relay.nn.softmax( + data = data, + axis = axis + ) + + + elif op == "maxpool2d": + simple_net = relay.nn.max_pool2d( + data = data, + pool_size = pool_size, + strides = strides, + padding = padding, + ) + + # Create workload + inputs = relay.analysis.free_vars(simple_net) + simple_net = relay.Function(inputs, simple_net) + + return simple_net + + + +def ref_tvm_build_cudnn(config): + def impl(neural_in): + op = config["op"] + target = gt_target + batch_size = config["batch_size"] + data_shape = config["data_shape"] + out_channels = config["out_channels"] + kernel_size = config["kernel_size"] + strides = config["strides"] + padding = config["padding"] + dilation = config["dilation"] + groups = config["groups"] + data_layout = config["data_layout"] + kernel_layout = config["kernel_layout"] + out_layout = config["out_layout"] + out_dtype = config["out_dtype"] + pool_size = config["pool_size"] + axis = config["axis"] + + + key = genKey(config) + + # Define input tensor shapes and variables + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight") + bn_gamma = relay.var("bn_gamma") + bn_beta = relay.var("bn_beta") + bn_mmean = relay.var("bn_mean") + bn_mvar = relay.var("bn_var") + bias = relay.var("bias") + + # Process given operators + if op == "conv2d": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + elif op == "bn": + simple_net = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar, axis)[0] + elif op == "relu": + simple_net = relay.nn.relu(data) + elif op == "conv2d+bn": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + elif op == "bn+relu": + simple_net = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + simple_net = relay.nn.relu(simple_net) + + elif op == "biasadd": + simple_net = relay.nn.bias_add( + data = data, + bias = bias, + axis = axis + ) + elif op 
== "conv2d+bn+relu": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + simple_net = relay.nn.relu(simple_net) + elif op == "conv2d+bias+relu": + simple_net = relay.nn.conv2d( + data = data, + weight = weight, # conv kernel + strides = strides, + padding = padding, + dilation = dilation, + groups = groups, + channels = out_channels, + kernel_size = kernel_size, + data_layout = data_layout, + kernel_layout = kernel_layout, + out_layout = out_layout, + out_dtype = "", + ) + simple_net = relay.nn.bias_add(simple_net, bias) + simple_net = relay.nn.relu(simple_net) + + elif op == "softmax": + simple_net = relay.nn.softmax( + data = data, + axis = axis + ) + + + elif op == "maxpool2d": + simple_net = relay.nn.max_pool2d( + data = data, + pool_size = pool_size, + strides = strides, + padding = padding, + ) + + # Create workload + inputs = relay.analysis.free_vars(simple_net) + simple_net = relay.Function(inputs, simple_net) + + simple_net, params = testing.create_workload(simple_net) + simple_net = simple_net["main"].with_attr("CustomFusionPass", 1) + + + params = neural_in["params"] + opt_level = 2 + with tvm.transform.PassContext(opt_level=opt_level): + lib = relay.build(simple_net, tvm.target.cuda(), params=params) + + dev = tvm.device("cuda", 0) + #dev = tvm.device("cuda -libs=cudnn", 0) + #lib = relay.build_module.build(simple_net, "cuda") + mod = runtime.GraphModule(lib["default"](dev)) + mod.set_input("data", neural_in["data"]) + mod.set_input(**params) + mod.run() + return mod.get_output(0) + + return impl + + + + +def ref_tvm_op_build_cudnn(config): + from tvm.contrib import cudnn + + def check_implementation(op_name): + if not tvm.get_global_func(op_name, allow_missing=True): + raise Exception("Not compiled with fused cudnn support; can't build this tutorial") + + def impl(neural_in): + op = config["op"] + target = config["target"] + batch_size = config["batch_size"] + data_shape = config["data_shape"] + out_channels = config["out_channels"] + kernel_size = config["kernel_size"] + strides = config["strides"] + padding = config["padding"] + dilation = config["dilation"] + groups = config["groups"] + data_layout = config["data_layout"] + kernel_layout = config["kernel_layout"] + out_layout = config["out_layout"] + out_dtype = config["out_dtype"] + pool_size = config["pool_size"] + axis = config["axis"] + + # NOTE: CUDNN SPECIFIC CONFIGS + conv_mode = 1 # mode: CUDNN_CONVOLUTION + conv_algo = -1 # pick the best performing one via measurement + activation_mode = 1 # CUDNN_RELU + nanProp_mode = 0 # CUDNN_NOT_PROPAGATE_NAN + full_dims = 4 + dims = full_dims-2 + actvCoeff = 1e100 + dtype = 'float32' + + # bias_shape == (1,out_channels) + + + key = genKey(config) + + # NOTE: Currently, only supports certain cases + if data_layout == "NCHW": + data_layout = 0 + in_channels = data_shape[1] + else: + assert(0) + + if len(kernel_size) == 2: + kernel_size = [ out_channels, in_channels, *kernel_size ] + + + # params + params = neural_in["params"] + dev = tvm.device(target, 0) + + + if "conv2d" in op: + + output_shape = cudnn.conv_output_shape( + data_layout, + padding, + strides, + dilation, + list(data_shape), + list(kernel_size), + dtype, + dtype, + 
                groups
            )

            if conv_algo == -1:
                # For now, if we try to call `cudnnFindConvolutionForwardAlgorithm` when
                # using the INT8 data type, cuDNN will crash.
                # On the other hand, cuDNN only supports IMPLICIT_PRECOMP_GEMM with the NHWC format
                if data_layout == 1 and dtype == "int32":
                    conv_algo = 1
                else:
                    conv_algo = cudnn.conv_find_algo(
                        data_layout,
                        padding,
                        strides,
                        dilation,
                        list(data_shape),
                        list(kernel_size),
                        output_shape,
                        dtype,
                        dtype,
                        groups
                    )

        # Process given operators
        if op == "conv2d+bias+relu":
            #padding, strides, dilation, _, _ = cudnn._prepare_global_func_params(dims, padding, strides, dilation)

            # Define input tensor shapes and variables
            te_data = te.placeholder(data_shape, name="data", dtype=dtype)
            te_kernel = te.placeholder(kernel_size, name="kernel", dtype=dtype)
            te_z = te.placeholder(output_shape, name="Z", dtype=dtype)
            te_bias = te.placeholder(params["bias"].shape, name="bias", dtype=dtype)

            cuDNN_OP = te.extern(
                output_shape,
                [te_data, te_kernel, te_z, te_bias],
                lambda ins, outs: tvm.tir.call_packed(
                    "tvm.contrib.cudnn.conv2d+bias+activation.forward",
                    conv_mode,    # mode: CUDNN_CONVOLUTION
                    data_layout,  # CUDNN_TENSOR_NCHW
                    conv_algo,
                    padding[0], padding[1],
                    strides[0], strides[1],
                    dilation[0], dilation[1],
                    dtype,
                    ins[0],   # x
                    ins[1],   # w
                    ins[2],   # z
                    ins[3],   # bias
                    outs[0],  # y
                    groups,
                    1,  # alphas[0]
                    0,  # alphas[1]
                    1,  # alphas[0] for z
                    0,
                    activation_mode,
                    nanProp_mode,
                    actvCoeff
                ),
                name="y",
            )

            s = te.create_schedule(cuDNN_OP.op)
            func = tvm.build(s, [te_data, te_kernel, te_z, te_bias, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm")

            # convert np.ndarray to tvm.nd.array
            data = tvm.nd.array(neural_in["data"], dev)
            weight = tvm.nd.array(params["weight"], dev)
            ze = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev)
            bias = tvm.nd.array(params["bias"], dev)
            output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev)
            func(data, weight, ze, bias, output)

        elif op == "conv2d":
            # Define input tensor shapes and variables
            te_data = te.placeholder(data_shape, name="data", dtype=dtype)
            te_kernel = te.placeholder(kernel_size, name="kernel", dtype=dtype)

            cuDNN_OP = te.extern(
                output_shape,
                [te_data, te_kernel],
                lambda ins, outs: tvm.tir.call_packed(
                    "tvm.contrib.cudnn.conv2d.forward",
                    conv_mode,    # mode: CUDNN_CONVOLUTION
                    data_layout,  # CUDNN_TENSOR_NCHW
                    conv_algo,
                    padding[0], padding[1],
                    strides[0], strides[1],
                    dilation[0], dilation[1],
                    ins[0],   # x
                    ins[1],   # w
                    outs[0],  # y
                    dtype,
                    groups,
                ),
                name="y",
            )

            s = te.create_schedule(cuDNN_OP.op)
            func = tvm.build(s, [te_data, te_kernel, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm")

            data = tvm.nd.array(neural_in["data"], dev)
            weight = tvm.nd.array(params["weight"], dev)
            output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev)

            func(data, weight, output)

        elif op == "softmax":
            te_data = te.placeholder(data_shape, name="data", dtype=dtype)
            output_shape = data_shape
            cuDNN_OP = te.extern(
                output_shape,
                [te_data],
                lambda ins, outs: tvm.tir.call_packed(
                    "tvm.contrib.cudnn.softmax.forward",
                    ins[0],   # x
                    outs[0],  # y
                    axis
                ),
                name="y",
            )
            s = te.create_schedule(cuDNN_OP.op)
            func = tvm.build(s, [te_data, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm")
            data = tvm.nd.array(neural_in["data"], dev)
            output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev)

            func(data, output)

        elif op == "biasadd" or op ==
"add": + te_data = te.placeholder(data_shape, name="data", dtype=dtype) + te_bias = te.placeholder(params["bias"].shape, name="bias", dtype=dtype) + output_shape = data_shape + + #assert(axis==-1 or axis==1) + cuDNN_OP = te.extern( + output_shape, + [te_data, te_bias], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.add", + ins[0], # x + outs[0], # y + 1,0, # alpha, beta + axis + ), + name="y", + ) + s = te.create_schedule(cuDNN_OP.op) + func = tvm.build(s, [te_data, te_bias, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm") + + data = tvm.nd.array(neural_in["data"], dev) + bias = tvm.nd.array(params["bias"], dev) + output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev) + + func(data, bias, output) + + + elif op == "relu": + te_data = te.placeholder(data_shape, name="data", dtype=dtype) + output_shape = data_shape + cuDNN_OP = te.extern( + output_shape, + [te_data], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.activation.forward", + ins[0], # x + outs[0], # y, + 1,0, #alpha, beta + activation_mode, + nanProp_mode, + actvCoeff + ), + name="y", + ) + s = te.create_schedule(cuDNN_OP.op) + func = tvm.build(s, [te_data, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm") + data = tvm.nd.array(neural_in["data"], dev) + output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev) + + func(data, output) + + + + elif op == "maxpool2d": + te_data = te.placeholder(data_shape, name="data", dtype=dtype) + output_shape = list(data_shape) + #outputDim = 1 + (inputDim + 2*padding - windowDim)/poolingStride; + for i in range(dims): + output_shape[i+2] = int(1 + (data_shape[i+2] + 2*padding[i]-pool_size[i])/strides[i]) + + + cuDNN_OP = te.extern( + output_shape, + [te_data], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.pooling.forward", + ins[0], # x + outs[0], # y + 1, 0, # Alpha, beta + 3, # MODE: CUDNN_POOLING_MAX_DETERMINISTIC + nanProp_mode, + pool_size[0], pool_size[1], + padding[0], padding[1], + strides[0], strides[1] + ), + name="y", + ) + + s = te.create_schedule(cuDNN_OP.op) + func = tvm.build(s, [te_data, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm") + + data = tvm.nd.array(neural_in["data"], dev) + output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev) + + func(data, output) + + + elif op == "bn": + stat_shape = (1,in_channels,1,1) + te_data = te.placeholder(data_shape, name="data", dtype=dtype) + te_bn_gamma = te.placeholder(stat_shape, name="bn_gamma", dtype=dtype) + te_bn_beta = te.placeholder(stat_shape, name="bn_beta", dtype=dtype) + te_bn_mean = te.placeholder(stat_shape, name="bn_mean", dtype=dtype) + te_bn_var = te.placeholder(stat_shape, name="bn_var", dtype=dtype) + #axis + + + eps = 1e-5 + output_shape = data_shape + + + # BN mode + # CUDNN_BATCHNORM_PER_ACTIVATION(0): param dim should be 1xCxHxW: axis = 0 + # CUDNN_BATCHNORM_SPATIAL(1): param dim should be 1xCx1x1 axis = 1 + # CUDNN_BATCHNORM_SPATIAL_PERSISTENT(1): param dim should be 1xCx1x1 + + cuDNN_OP = te.extern( + output_shape, + [te_data, te_bn_gamma, te_bn_beta, te_bn_mean, te_bn_var], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.batchnorm.forward", + 1, #MODE + ins[0], # x + outs[0], # y + ins[1], # scale = gamma + ins[2], # bias = beta + ins[3], # mean + ins[4], # var + 1, 0, # Alpha, beta + eps + ), + name="y", + ) + + s = te.create_schedule(cuDNN_OP.op) + func = tvm.build(s, [te_data, te_bn_gamma, te_bn_beta, te_bn_mean, te_bn_var, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm") + + data = tvm.nd.array(neural_in["data"], dev) + 
+        elif op == "bn":
+            stat_shape = (1, in_channels, 1, 1)
+            te_data = te.placeholder(data_shape, name="data", dtype=dtype)
+            te_bn_gamma = te.placeholder(stat_shape, name="bn_gamma", dtype=dtype)
+            te_bn_beta = te.placeholder(stat_shape, name="bn_beta", dtype=dtype)
+            te_bn_mean = te.placeholder(stat_shape, name="bn_mean", dtype=dtype)
+            te_bn_var = te.placeholder(stat_shape, name="bn_var", dtype=dtype)
+            #axis
+
+            eps = 1e-5
+            output_shape = data_shape
+
+            # BN mode
+            # CUDNN_BATCHNORM_PER_ACTIVATION(0): param dim should be 1xCxHxW: axis = 0
+            # CUDNN_BATCHNORM_SPATIAL(1): param dim should be 1xCx1x1: axis = 1
+            # CUDNN_BATCHNORM_SPATIAL_PERSISTENT(2): param dim should be 1xCx1x1
+            cuDNN_OP = te.extern(
+                output_shape,
+                [te_data, te_bn_gamma, te_bn_beta, te_bn_mean, te_bn_var],
+                lambda ins, outs: tvm.tir.call_packed(
+                    "tvm.contrib.cudnn.batchnorm.forward",
+                    1,  # MODE
+                    ins[0],  # x
+                    outs[0],  # y
+                    ins[1],  # scale = gamma
+                    ins[2],  # bias = beta
+                    ins[3],  # mean
+                    ins[4],  # var
+                    1, 0,  # alpha, beta
+                    eps
+                ),
+                name="y",
+            )
+
+            s = te.create_schedule(cuDNN_OP.op)
+            func = tvm.build(s, [te_data, te_bn_gamma, te_bn_beta, te_bn_mean, te_bn_var, cuDNN_OP], "cuda -libs=cudnn", target_host="llvm")
+
+            data = tvm.nd.array(neural_in["data"], dev)
+            gamma = tvm.nd.array(params["bn_gamma"].asnumpy().reshape(stat_shape), dev)
+            beta = tvm.nd.array(params["bn_beta"].asnumpy().reshape(stat_shape), dev)
+            mean = tvm.nd.array(params["bn_mean"].asnumpy().reshape(stat_shape), dev)
+            var = tvm.nd.array(params["bn_var"].asnumpy().reshape(stat_shape), dev)
+            output = tvm.nd.array(np.zeros(output_shape, dtype=dtype), dev)
+
+            func(data, gamma, beta, mean, var, output)
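+            # Parameter layout assumed above (SPATIAL mode, a sketch): the
+            # framework stores gamma/beta/mean/var as (C,) vectors, e.g. (16,)
+            # for C=16, while cuDNN expects (1, C, 1, 1); that is exactly what
+            # the reshape(stat_shape) calls provide.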
+        else:
+            assert False, f"Unsupported op: {op}"
+
+        return output
+
+    check_implementation("tvm.contrib.cudnn.batchnorm.forward")
+    check_implementation("tvm.contrib.cudnn.activation.forward")
+    check_implementation("tvm.contrib.cudnn.add")
+    check_implementation("tvm.contrib.cudnn.pooling.forward")
+    check_implementation("tvm.contrib.cudnn.reduce")
+    check_implementation("tvm.contrib.cudnn.scale")
+    check_implementation("tvm.contrib.cudnn.conv2d+bias+activation.forward")
+    check_implementation("tvm.contrib.cudnn.softmax.forward")
+
+    return impl
+
+
+def ref_impl(config):
+    def impl(neural_in):
+        op = config["op"]
+        target = config["target"]
+        batch_size = config["batch_size"]
+        data_shape = config["data_shape"]
+        out_channels = config["out_channels"]
+        kernel_size = config["kernel_size"]
+        strides = config["strides"]
+        padding = config["padding"]
+        dilation = config["dilation"]
+        groups = config["groups"]
+        data_layout = config["data_layout"]
+        kernel_layout = config["kernel_layout"]
+        out_layout = config["out_layout"]
+        out_dtype = config["out_dtype"]
+
+        key = genKey(config)
+
+        # Define input tensor shapes and variables
+        data = relay.var("data", relay.TensorType(data_shape, "float32"))
+        weight = relay.var("weight")
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        bias = relay.var("bias")
+
+        #weight = relay.var("weight", shape=((out_channels, data_shape[1], kernel_size[0], kernel_size[1])))
+        #bn_gamma = relay.var("bn_gamma", shape=(out_channels,))
+        #bn_beta = relay.var("bn_beta", shape=(out_channels,))
+        #bn_mmean = relay.var("bn_mean", shape=(out_channels,))
+        #bn_mvar = relay.var("bn_var", shape=(out_channels,))
+
+        # Process given operators
+        if op == "conv2d":
+            simple_net = relay.nn.conv2d(
+                data = data,
+                weight = weight,
+                strides = strides,
+                padding = padding,
+                dilation = dilation,
+                groups = groups,
+                channels = out_channels,
+                kernel_size = kernel_size,
+                data_layout = data_layout,
+                kernel_layout = kernel_layout,
+                out_layout = out_layout,
+                out_dtype = "",
+            )
+        elif op == "bn":
+            simple_net = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
+        elif op == "relu":
+            simple_net = relay.nn.relu(data)
+        elif op == "conv2d+bn":
+            simple_net = relay.nn.conv2d(
+                data = data,
+                weight = weight,
+                strides = strides,
+                padding = padding,
+                dilation = dilation,
+                groups = groups,
+                channels = out_channels,
+                kernel_size = kernel_size,
+                data_layout = data_layout,
+                kernel_layout = kernel_layout,
+                out_layout = out_layout,
+                out_dtype = "",
+            )
+            simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
+        elif op == "bn+relu":
+            simple_net = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
+            simple_net = relay.nn.relu(simple_net)
+        elif op == "conv2d+bn+relu":
+            simple_net = relay.nn.conv2d(
+                data = data,
+                weight = weight,
+                strides = strides,
+                padding = padding,
+                dilation = dilation,
+                groups = groups,
+                channels = out_channels,
+                kernel_size = kernel_size,
+                data_layout = data_layout,
+                kernel_layout = kernel_layout,
+                out_layout = out_layout,
+                out_dtype = "",
+            )
+            simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
+            simple_net = relay.nn.relu(simple_net)
+        elif op == "conv2d+bias+relu":
+            simple_net = relay.nn.conv2d(
+                data = data,
+                weight = weight,
+                strides = strides,
+                padding = padding,
+                dilation = dilation,
+                groups = groups,
+                channels = out_channels,
+                kernel_size = kernel_size,
+                data_layout = data_layout,
+                kernel_layout = kernel_layout,
+                out_layout = out_layout,
+                out_dtype = "",
+            )
+            simple_net = relay.nn.bias_add(simple_net, bias)
+            simple_net = relay.nn.relu(simple_net)
+
+        # Create workload
+        inputs = relay.analysis.free_vars(simple_net)
+
+        simple_net = relay.Function(inputs, simple_net)
+
+        net, _ = testing.create_workload(simple_net)
+
+        params = neural_in["params"]
+        # Build the subgraph
+        dev = tvm.device(target, 0)
+        lib = relay.build_module.build(net, target, params=params)
+        module = runtime.GraphModule(lib["default"](dev))
+
+        # Setup execution
+        module.set_input("data", neural_in["data"])
+
+        module.run()
+        # get output
+        out = module.get_output(0)
+        #print(out.asnumpy().shape)
+        return out
+    return impl
+
+configs = [
+
+    #["conv2d", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    ["conv2d+bias+relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["conv2d", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (2,2), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1,(1,16)],
+    #["conv2d+bias+relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (2,2), (1,1), (1,1), 1, "NCHW", "OIHW", "", "",(2,2), -1, (1,16)],
+    #["softmax", "cuda -libs=cudnn", 1, (1,1,4,4), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["softmax", "cuda -libs=cudnn", 1, (1,1,4,4), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 0, (1,16)],
+    #["relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["biasadd", "cuda -libs=cudnn", 1, (13,31,224,12), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["biasadd", "cuda -libs=cudnn", 1, (13,31,224,12), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 1, (1,31,224)],
+    #["biasadd", "cuda -libs=cudnn", 1, (13,31,224,12), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 0, (1,31,224)],
+    #["biasadd", "cuda -libs=cudnn", 1, (13,31,224,12), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 2, (1,31,224)],
+    #["maxpool2d", "cuda -libs=cudnn", 1, (1,1,4,4), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["maxpool2d", "cuda -libs=cudnn", 1, (13,31,224,12), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["maxpool2d", "cuda -libs=cudnn", 1, (224,224,112,12), 16, (2,2), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
+    #["bn", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 1, (1,16)],
+    #["bn", "cuda -libs=cudnn", 1, (11,23,24,24), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 1, (1,16)],
+    #["bn", "cuda -libs=cudnn", 1, (1,33,24,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 1, (1,16)],
+]
+
+dicts = []
+
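+# The config rows above are positional; the fields line up with the unpacking
+# in the loop below:
+#   [op, target, batch_size, data_shape, out_channels, kernel_size,
+#    strides, padding, dilation, groups, data_layout, kernel_layout,
+#    out_layout, out_dtype, pool_size, axis, bias_shape]
+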
+for config in configs:
+    # Define target backend and arguments for the subgraph
+    op, target, batch_size, data_shape, out_channels, kernel_size, \
+        strides, padding, dilation, groups, data_layout, kernel_layout, \
+        out_layout, out_dtype, pool_size, axis, bias_shape = config
+    d = {"op": op, "target": target, "batch_size": batch_size, "data_shape": data_shape, "out_channels": out_channels, "kernel_size": kernel_size,
+         "strides": strides, "padding": padding, "dilation": dilation, "groups": groups, "data_layout": data_layout, "kernel_layout": kernel_layout,
+         "out_layout": out_layout, "out_dtype": out_dtype, "bias_shape": bias_shape, "pool_size": pool_size, "axis": axis}
+    dicts.append(d)
+
+OUTPUT = "operator_cost.log"
+# change this to test your ops!
+CLIENT_IMPLEMENTATION = ref_tvm_build_cudnn
+#CLIENT_IMPLEMENTATION = ref_impl
+REPEAT = 1
+import json
+logs = {}
+
+@pyt.mark.parametrize("config", dicts)
+def test(config):
+    gt_network = get_gt_net(config)
+
+    reference_implementation = CLIENT_IMPLEMENTATION(config)
+
+    for i in range(REPEAT):
+        data = np.random.uniform(-1, 1, size=config["data_shape"]).astype("float32")
+        #bias = np.random.uniform(-1, 1, size=config["bias_shape"]).astype("float32")
+        net, params = testing.create_workload(gt_network)
+
+        # Build the subgraph
+        dev = tvm.device(gt_target, 0)
+        lib = relay.build_module.build(net, gt_target, params=params)
+        module = runtime.GraphModule(lib["default"](dev))
+
+        # Setup execution
+        module.set_input("data", data)
+
+        module.run()
+        # get output
+        out_gt = module.get_output(0).asnumpy()
+        out_impl = reference_implementation({"data": data, "params": params}).asnumpy()
+
+        assert out_gt.shape == out_impl.shape
+        assert pyt.approx(out_impl, rel=1e-7, abs=1e-7) == out_gt
+
+if __name__ == '__main__':
+    for config in dicts:
+        test(config)
+
+    print("\n============= Completed ==================")
diff --git a/python/tvm/relay/transform/backend_operator/utils.py b/python/tvm/relay/transform/backend_operator/utils.py
index 9ce47bc05a16..39e453e28997 100644
--- a/python/tvm/relay/transform/backend_operator/utils.py
+++ b/python/tvm/relay/transform/backend_operator/utils.py
@@ -1,6 +1,7 @@
 from tvm import relay
 from tvm.relay.dataflow_pattern import *
 from collections import namedtuple
+import numpy as np
 
 from .pattern import Pattern
 
@@ -18,11 +19,42 @@ def get_diamond():
 
 # return the shape of input data to expr
 def get_data_shape(expr):
     inputs = relay.analysis.free_vars(expr)
+    # if is_call_node(expr):
+    #     print(f"Input for expr ({expr.op}) {[inputs[0].type_annotation]}, {expr.attrs.axis}")
+
+    # For add, the shapes of lhs and rhs should be identical;
+    # for all other backend ops, we take the shape of the "data" input arg.
-    data_shape_imm = inputs[0].type_annotation.shape
-    data_shape = tuple(map(lambda x: x.value, data_shape_imm))
+    # inputs[0] corresponds to Var(name_hint='data').
+    # We consider two different types for it: relay.TupleType, relay.TensorType
+    if type(inputs[0].type_annotation) == relay.TensorType:
+        data_shape_imm = inputs[0].type_annotation.shape
+        data_shape = list(map(lambda x: x.value, data_shape_imm))
+    elif type(inputs[0].type_annotation) == relay.TupleType:
+        data_shape = []
+        for tup_item in inputs[0].type_annotation.fields:
+            data_shape.append(tuple(map(lambda x: x.value, tup_item.shape)))
+        data_shape = tuple(data_shape)
+        print("data shape", data_shape)
+    else:
+        raise Exception(f"Unsupported Var type ({type(inputs[0].type_annotation)})")
+
     return data_shape
 
+
+def get_data(expr):
+    data_shape = get_data_shape(expr)
+    if type(data_shape) == list:
+        print(data_shape)
+        data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
+    elif type(data_shape) == tuple:
+        data = []
+        for shape in data_shape:
+            data.append(np.random.uniform(-1, 1, size=shape).astype("float32"))
+            print(shape)
+        data = tuple(data)
+    else:
+        raise Exception(f"Unsupported data shape type {type(data_shape)}")
+
+    return data
+
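+# Illustrative contract for the two helpers above (hypothetical shapes):
+#   a conv2d input  -> get_data_shape(expr) == [1, 3, 224, 224]             (list, from TensorType)
+#   a concat Tuple  -> get_data_shape(expr) == ((1, 3, 2, 2), (1, 3, 2, 2)) (tuple of tuples, from TupleType)
+# get_data mirrors that: a list yields a single ndarray, a tuple yields a
+# tuple of ndarrays, one per tuple field.
+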
 def is_function_node(expr):
     return type(expr) == tvm.relay.Function
 
@@ -32,12 +64,15 @@ def is_constant_node(expr):
 def is_call_node(expr):
     return type(expr) == tvm.relay.expr.Call
 
+def is_tuple_node(expr):
+    return type(expr) == tvm.relay.expr.Tuple
+
 def is_tuplegetitem_node(expr):
     return type(expr) == tvm.relay.expr.TupleGetItem
 
 def is_call_or_tuplegetitem_node(expr):
     # If not, it means that we need to add codes to deal with other nodes
-    assert is_call_node(expr) or is_tuplegetitem_node(expr) or is_var_node(expr) or is_constant_node(expr)
+    assert is_call_node(expr) or is_tuplegetitem_node(expr) or is_var_node(expr) or is_constant_node(expr) or is_tuple_node(expr)
     return is_call_node(expr) or is_tuplegetitem_node(expr)
 
 def is_var_node(expr):
@@ -52,9 +87,10 @@ def get_attr_vals(expr):
     attrs = expr.attrs
     op_name = expr.op.name
 
-    if attrs == None:
+    if attrs is None or "keys" not in dir(attrs):
         return (op_name, "")
 
+    # print(f"{expr.op}'s Attrs : {dir(attrs)}")
     keys = attrs.keys()
     values = []
     for key in keys:
diff --git a/python/tvm/relay/transform/optimizer/_optimizer.py b/python/tvm/relay/transform/optimizer/_optimizer.py
index 7357ed43a7fe..ec303b6e1ee6 100644
--- a/python/tvm/relay/transform/optimizer/_optimizer.py
+++ b/python/tvm/relay/transform/optimizer/_optimizer.py
@@ -12,6 +12,7 @@
 from .comp_graph import ComputationGraph
 from .comp_graph_optimizer import CompGraphOptimizer
+from ..utility.visualize import visualize_network
 from .optimizer_utils import print_matching_final
 
 def setup_backend_op_lib(network_expr, targets, batch_size):
@@ -39,6 +40,7 @@ def optimize_comp_graph(relay_expr):
     # It is a function if you get it from last pass of Relay build
     # print("Relay expression")
     # print(relay_expr)
+    # visualize_network(relay_expr, "nasneta_opt")
 
     if type(relay_expr) == tvm.relay.function.Function:
         relay_expr = relay_expr.body
@@ -48,12 +50,14 @@ def optimize_comp_graph(relay_expr):
     # target_backend = None # Consider all targets
     # targets = [Target.TENSORRT, Target.CUDNN, Target.CUBLAS, Target.TVM_GPU_NO_TUNING, Target.TVM_CPU]
-    # targets = [Target.TVM_GPU_NO_TUNING, Target.TVM_GPU]
     # targets = [Target.TVM_GPU_AUTOSCH]
     # targets = [Target.TENSORRT]
     # targets = [Target.CUDNN, Target.TVM_GPU_NO_TUNING]
     # targets = [Target.TENSORRT, Target.TVM_GPU_NO_TUNING]
-    targets = [Target.TVM_GPU_NO_TUNING, Target.TENSORRT, Target.CUDNN]
+    # targets = [Target.TVM_GPU_NO_TUNING, Target.TVM_GPU_AUTOSCH, Target.TENSORRT, Target.CUDNN]#, Target.CUBLAS]
+    # targets = [Target.CUDNN]
+    # targets = [Target.TVM_GPU_NO_TUNING]
+    targets = [Target.TVM_GPU_NO_TUNING, Target.CUDNN] # , Target.CUBLAS]
 
     batch_size = 1
     backendop_lib = setup_backend_op_lib(relay_expr, targets, batch_size)
diff --git a/python/tvm/relay/transform/optimizer/comp_graph.py b/python/tvm/relay/transform/optimizer/comp_graph.py
index 7ab4fd876052..487cd8bfa916 100644
--- a/python/tvm/relay/transform/optimizer/comp_graph.py
+++ b/python/tvm/relay/transform/optimizer/comp_graph.py
@@ -1,6 +1,7 @@
 import tvm
 from tvm import relay
-from ..backend_operator.utils import is_call_node, is_tuplegetitem_node, is_var_node, is_constant_node, is_function_node
+from ..backend_operator.utils import *
+from .optimizer_utils import is_data_var_node
 
 class Node:
     def __init__(self, relay_expr, topological_order):
@@ -59,14 +60,20 @@ def get_n_nodes(self):
     def _get_n_nodes(self, relay_expr):
         self._memo[hash(relay_expr)] = True
         n_nodes = 1
-        if is_constant_node(relay_expr) or (is_var_node(relay_expr) and relay_expr.name_hint != 'data'):
+        if is_constant_node(relay_expr) or (is_var_node(relay_expr) and not is_data_var_node(relay_expr)):
             n_nodes = 0
-        elif is_var_node(relay_expr) and relay_expr.name_hint == 'data':
+        elif is_var_node(relay_expr) and is_data_var_node(relay_expr):
             n_nodes = 1
         elif is_tuplegetitem_node(relay_expr):
             next_expr = relay_expr.tuple_value
             if hash(next_expr) not in self._memo:
                 n_nodes += self._get_n_nodes(next_expr)
+        elif is_tuple_node(relay_expr):
+            for node_idx, node in enumerate(relay_expr.fields):
+                if hash(node) not in self._memo:
+                    # memorize this visit to prevent it from visiting twice
+                    # +1 here means counting the current node
+                    n_nodes += self._get_n_nodes(node)
         elif is_call_node(relay_expr):
             for node_idx, node in enumerate(relay_expr.args):
                 if hash(node) not in self._memo:
@@ -74,7 +81,7 @@ def _get_n_nodes(self, relay_expr):
                     # +1 here means counting the current node
                     n_nodes += self._get_n_nodes(node)
         else:
-            raise Exception("Unexpected Relay expr type")
+            raise Exception(f"Unexpected Relay expr type {type(relay_expr)}")
 
         return n_nodes
 
@@ -89,13 +96,25 @@ def _add_node(self, relay_expr, topological_order, parent_expr):
         self._memo[hash(relay_expr)] = True
 
     def _expr2graph(self, relay_expr, topological_order, parent_expr):
-        if is_constant_node(relay_expr) or (is_var_node(relay_expr) and relay_expr.name_hint != 'data'):
+        if is_constant_node(relay_expr) or (is_var_node(relay_expr) and not is_data_var_node(relay_expr)):
             return
         else:
             self._add_node(relay_expr, topological_order, parent_expr)
 
-        if is_var_node(relay_expr) and relay_expr.name_hint == 'data':
+        if is_var_node(relay_expr) and is_data_var_node(relay_expr):
             return
+        elif is_tuple_node(relay_expr):
+            for node_idx, node in enumerate(relay_expr.fields):
+                if hash(node) not in self._memo:
+                    # memorize this visit to prevent it from visiting twice
+                    self._expr2graph(node, topological_order + 1, relay_expr)
+                else:
+                    # Make sure the node has a right (larger) topological order
+                    # if there are multiple choices
+                    if self.expr2node[hash(node)]._topological_order < topological_order:
+                        self.expr2node[hash(node)]._topological_order = topological_order
+                    self.expr2node[hash(node)].add_parent(relay_expr)
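+        # Example of the ordering rule above: in a diamond where two Tuple
+        # fields share one producer, the shared node is reached along both
+        # paths; keeping the larger topological_order (and recording every
+        # parent) places it after the deepest path that feeds it.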
         elif is_tuplegetitem_node(relay_expr):
             # If it is tuple, you should use tuple_value instead of args
             next_expr = relay_expr.tuple_value
diff --git a/python/tvm/relay/transform/optimizer/comp_graph_optimizer.py b/python/tvm/relay/transform/optimizer/comp_graph_optimizer.py
index 3b19251a565d..51d541bc1307 100644
--- a/python/tvm/relay/transform/optimizer/comp_graph_optimizer.py
+++ b/python/tvm/relay/transform/optimizer/comp_graph_optimizer.py
@@ -26,7 +26,7 @@ def match(self, expr):
         self._memo = {}
         self._optimized_match = {}
 
-        dummy_annotation = (9999999, "NO_PYTHON_OP")
+        dummy_annotation = (9999999, "PYTHON_INVALID_BACKEND_OP")
         self.visit_expr(expr, dummy_annotation)
 
         return self._optimized_match, self._topo_order_to_op
@@ -67,6 +67,13 @@ def visit_expr(self, expr, annotation):
         # is_leaf_node = is_constant_node(expr) or is_var_node(expr)
         # if not is_leaf_node:
         if expr not in self._optimized_match:
+            # annotation[0] -> group number, annotation[1] -> backend op name
+            # Func(Relu2(Conv2(Func(Relu1(Conv1)))))
+            # Dictionary:
+            #   Conv1 : 0, tensorrt_fused_conv
+            #   Relu1 : 0, tensorrt_fused_conv
+            #   Conv2 : 1, tensorrt_fused_conv
+            #   Relu2 : 1, tensorrt_fused_conv
             self._optimized_match[expr] = f"{annotation[0]}-{annotation[1]}"
             self._topo_order_to_op.append((node_type, self._optimized_match[expr]))
         else:
@@ -172,33 +179,49 @@ def optimize(self, comp_graph):
         pair2match = {}
         self.loc2match = {hash(comp_graph.get_root()): {"match":[], "cost":0, "string":""}}
         while not frontiers.empty():
+            # Facilitate the debugging process
+            self._backendop_lib.save_to_log()
             f = frontiers.get()
             f_expr = f.get_relay_expr()
-            print("Topologicla order : ", f._topological_order)
+            if is_call_node(f_expr):
+                print(f"(topo_order, op_type) : {f._topological_order}, {f_expr.op}")
+            else:
+                print(f"(topo_order, op_type) : {f._topological_order}, {f_expr}, Non-call node")
+
+            # print(self._backendop_lib.get_all_patterns())
             for pat in self._backendop_lib.get_all_patterns():
                 # print(pat)
                 if pat.get_pattern().match(f_expr):
                     # Check if there is an existing frontier with the same goal idx
+                    # Conv(Data, Weight)
+                    # get_next_expr_after_match -> [Data, Weight]
+                    # next_expr_after_match = Conv()
                     assert get_pattern_len(pat.get_pattern()) >= 1
                     tuple_after_matches = get_next_expr_after_match(f_expr, None, get_pattern_len(pat.get_pattern()))
-
+                    print("PATTERN MATCHED", pat.get_pattern())
                    # Consdier only valid nodes
                    tuple_after_matches = [tup for tup in tuple_after_matches if hash(tup[0]) in comp_graph.expr2node]

                    for t_idx, (expr_after_match, prev_expr_after_match) in enumerate(tuple_after_matches):
                        # Get new frontier, matched backend ops, and their costs
                        new_loc = comp_graph.expr2node[hash(expr_after_match)]
                        pat_op, pat_cost = get_optimal_backendop(self._backendop_lib, f_expr, pat, self._target_backend)
+
+                        # new_match = self.loc2match[hash(f)]["match"] + [(pat_op, pat_cost, hash(f_expr))]
+                        # new_cost = self.loc2match[hash(f)]["cost"] + pat_cost
+                        # new_string = self.loc2match[hash(f)]['string'] + "-" + self._pattern_to_name[pat]
+
                        # Flush matchings from second branch if there are more than one branches
                        if t_idx == 0:
-                            new_match = self.loc2match[hash(f)]["match"] + [(pat_op, pat_cost, hash(f))]
+                            new_match = self.loc2match[hash(f)]["match"] + [(pat_op, pat_cost, hash(f_expr))]
                            new_cost = self.loc2match[hash(f)]["cost"] + pat_cost
                            new_string = self.loc2match[hash(f)]['string'] + "-" + self._pattern_to_name[pat]
+                            print(f"Assign matched op : {pat_op}")
                        else:
                            new_match, new_cost, new_string = [], 0, "+"

                        # Maintain pair2match for keeping track of match results for each branch
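+                        # Bookkeeping sketch for the structure below:
+                        #   pair2match[hash(new_loc)][hash(prev_expr_after_match)]
+                        # i.e. one slot per (merge point, incoming branch), so a
+                        # node fed by several branches can later combine one
+                        # partial match result per branch.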
                        new_loc.matched_expr[hash(prev_expr_after_match)] = 1
-                        out_key = hash(new_loc)
+                        out_key = hash(new_loc)  # new_loc is node
                        in_key = hash(prev_expr_after_match)

                        if out_key not in pair2match:
@@ -331,4 +354,4 @@ def get_optimized_match(self, comp_graph):
 
 
 
-
\ No newline at end of file
+
diff --git a/python/tvm/relay/transform/optimizer/optimizer_utils.py b/python/tvm/relay/transform/optimizer/optimizer_utils.py
index f0ef619f9b0f..e8f69a26a0e7 100644
--- a/python/tvm/relay/transform/optimizer/optimizer_utils.py
+++ b/python/tvm/relay/transform/optimizer/optimizer_utils.py
@@ -1,11 +1,22 @@
 import tvm
 from ..backend_operator.utils import is_call_node, is_tuplegetitem_node, is_var_node, is_constant_node, is_function_node
 
+DATA_NAME_HINTS = ['data', 'input', 'x']
+
+def is_data_var_node(expr):
+    is_data_var = False
+    for data_name_hint in DATA_NAME_HINTS:
+        if data_name_hint in expr.name_hint:
+            is_data_var = True
+            break
+    return is_data_var
+
+
 def get_next_expr_after_match(relay_expr, prev_relay_expr, depth):
     target_node = []
     if type(relay_expr) == tvm.relay.expr.Var:
-        if relay_expr.name_hint == 'data':
+        if is_data_var_node(relay_expr):
             return [(relay_expr, prev_relay_expr)]
         return [(None, prev_relay_expr)]
     elif is_constant_node(relay_expr):
@@ -19,6 +30,11 @@ def get_next_expr_after_match(relay_expr, prev_relay_expr, depth):
         if type(relay_expr) == tvm.relay.expr.TupleGetItem:
             target_node += get_next_expr_after_match(relay_expr.tuple_value, relay_expr, depth-1)
         else:
+            # Note that batch_matmul also has args
+            # if type(relay_expr) == tvm.relay.nn.batch_matmul:
+            #     target_node += get_next_expr_after_match(relay_expr.x, relay_expr, depth - 1)
+            #     target_node += get_next_expr_after_match(relay_expr.y, relay_expr, depth - 1)
+            # else:
             for node in relay_expr.args:
                 target_node += get_next_expr_after_match(node, relay_expr, depth-1)
 
     # # FIX: Hacky way to avoid residual connection
@@ -35,6 +51,11 @@ def get_pattern_len(pattern):
     elif type(pattern) == tvm.relay.dataflow_pattern.TupleGetItemPattern:
         length = get_pattern_len(pattern.tuple)
         length += 1
+    elif type(pattern) == tvm.relay.dataflow_pattern.TuplePattern:
+        for child in pattern.fields:
+            print(type(child))
+            length = max(length, get_pattern_len(child))
+        length += 1
 
     return length
 
@@ -73,4 +94,4 @@ def print_matching_debug(comp_graph, loc2match):
     for item in loc2match[hash(comp_graph._nodes[idx])]["match"][::-1]:
         op_name, op_cost = item
-        print(f"({op_name}, {op_cost:.2g})")
\ No newline at end of file
+        print(f"({op_name}, {op_cost:.2g})")
diff --git a/python/tvm/relay/transform/optimizer/test_ops.py b/python/tvm/relay/transform/optimizer/test_ops.py
index 89f238bc4a62..c87557fe7a91 100644
--- a/python/tvm/relay/transform/optimizer/test_ops.py
+++ b/python/tvm/relay/transform/optimizer/test_ops.py
@@ -295,10 +295,11 @@ def impl(neural_in):
         params = neural_in["params"]
 
         opt_level = 2
+        target_str = 'cuda'# -libs=cudnn'
         with tvm.transform.PassContext(opt_level=opt_level):
-            lib = relay.build(simple_net, tvm.target.cuda(), params=params)
+            lib = relay.build(simple_net, target_str, params=params)#tvm.target.cuda(), params=params)
 
-        dev = tvm.device("cuda", 0)
+        dev = tvm.device(target_str, 0)#"cuda", 0)
         #dev = tvm.device("cuda -libs=cudnn", 0)
         #lib = relay.build_module.build(simple_net, "cuda")
         mod = runtime.GraphModule(lib["default"](dev))
@@ -803,11 +804,13 @@ def impl(neural_in):
 
 configs = [
-    #["conv2d", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)],
"cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], + ["conv2d", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], + #["conv2d+bias+relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], #["conv2d", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (2,2), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1,(1,16)], #["conv2d+bias+relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (2,2), (1,1), (1,1), 1, "NCHW", "OIHW", "", "",(2,2), -1, (1,16)], - #["softmax", "cuda -libs=cudnn", 1, (1,1,4,4), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], + #["conv2d+bias+relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (2,2), (1,1), (1,1), 1, "NCHW", "OIHW", "", "",(2,2), -1, (1,16)], + + # ["softmax", "cuda -libs=cudnn", 1, (1,1,4,4), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], #["softmax", "cuda -libs=cudnn", 1, (1,1,4,4), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), 0, (1,16)], #["relu", "cuda -libs=cudnn", 1, (1,3,224,224), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], #["biasadd", "cuda -libs=cudnn", 1, (13,31,224,12), 16, (3,3), (1,1), (1,1), (1,1), 1, "NCHW", "OIHW", "", "", (2,2), -1, (1,16)], @@ -836,6 +839,7 @@ def impl(neural_in): OUTPUT = "operator_cost.log" # change this to test your ops! +# CLIENT_IMPLEMENTATION = ref_tvm_op_build_cudnn CLIENT_IMPLEMENTATION = ref_tvm_build_cudnn #CLIENT_IMPLEMENTATION = ref_impl REPEAT = 1 diff --git a/python/tvm/relay/transform/optimizer/test_relay_build.py b/python/tvm/relay/transform/optimizer/test_relay_build.py index 6d4a43cd5011..807263501295 100644 --- a/python/tvm/relay/transform/optimizer/test_relay_build.py +++ b/python/tvm/relay/transform/optimizer/test_relay_build.py @@ -7,6 +7,36 @@ from ..workloads.onnx_workloads import get_network_from_onnx from ..workloads.torch_workloads import get_network_from_torch +def get_concat(): + data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32")) + data2 = relay.var("data2", relay.TensorType((1, 3, 224, 224), "float32")) + data3 = relay.var("data3", relay.TensorType((1, 3, 224, 224), "float32")) + data4 = relay.var("data4", relay.TensorType((1, 3, 224, 224), "float32")) + tup = relay.Tuple([data1, data2, data3, data4]) + return tup + #concat = relay.concatenate(tup, axis=1) + #return concat + +def get_conv(): + # Chain graph + out_channels = 16 + batch_size = 1 + + data = relay.var("data", relay.TensorType((batch_size, 3, 224, 224), "float32")) + conv_weight = relay.var("weight", relay.TensorType((out_channels, 3, 3, 3), "float32")) + # bn_gamma = relay.var("bn_gamma") + # bn_beta = relay.var("bn_beta") + # bn_mmean = relay.var("bn_mean") + # bn_mvar = relay.var("bn_var") + + # simple_net = relay.nn.relu(data) + # simple_net = relay.nn.relu(simple_net) + simple_net = relay.nn.conv2d( + data=data, weight=conv_weight, kernel_size=(3, 3), channels=out_channels, padding=(1, 1) + ) + + return simple_net + def get_chain_graph(): # Chain graph out_channels = 16 @@ -75,5 +105,16 @@ def build_network(net, params): # network_name = "resnet50" # mod, params, _, _ = get_network_from_onnx(network_name, batch_size=1) -mod, params, _, _ = get_network_from_torch("resnet50",1) +# mod, params, _, _ = get_network_from_torch("resnet_block", 1) +# mod, params, _, _ = 
get_network_from_torch("resnet50", 1) +# mod, params, _, _ = get_network_from_torch("resnext50_32x4d",1) +# mod, params, _, _ = get_network_from_torch("bert",1) +# mod, params, _, _ = get_network_from_torch("nasrnn",1) + +# print(get_concat().attrs.axis) + +from tvm.relay.dataflow_pattern import * +print(is_tuple([wildcard(), wildcard(), wildcard(), wildcard()]).match(get_concat())) + +mod, params, _, _ = get_network_from_torch("nasneta",1) build_network(mod["main"], params) diff --git a/python/tvm/relay/transform/workloads/onnx_workloads.py b/python/tvm/relay/transform/workloads/onnx_workloads.py index 74ccdd35187a..40e96c51d2f1 100644 --- a/python/tvm/relay/transform/workloads/onnx_workloads.py +++ b/python/tvm/relay/transform/workloads/onnx_workloads.py @@ -1,5 +1,5 @@ from tvm import relay -import tensorflow as tf +# import tensorflow as tf import os import onnx diff --git a/python/tvm/relay/transform/workloads/torch_workloads.py b/python/tvm/relay/transform/workloads/torch_workloads.py index 5927ef0d4b82..5591a3a80d58 100644 --- a/python/tvm/relay/transform/workloads/torch_workloads.py +++ b/python/tvm/relay/transform/workloads/torch_workloads.py @@ -5,12 +5,13 @@ import copy from .workloads import WORKLOADS_DIC -from ..baselines.pytorch.resnets import resnet50, resnext50_32x4d +from ..baselines.pytorch.resnets import resnet50, resnext50_32x4d, resnet_block from ..baselines.pytorch.nasnet_a import NASNetA from ..baselines.pytorch.nasrnn import NASRNN from ..baselines.pytorch.bert import BERT NETWORK_TO_TORCH_MODEL = { + "resnet_block": resnet_block, "resnet50" : resnet50, "resnext50_32x4d" : resnext50_32x4d, "nasneta" : NASNetA, @@ -20,7 +21,11 @@ def load_torch_model(name): # Get the model - model = NETWORK_TO_TORCH_MODEL[name]().cuda() + if name == "nasrnn": + model = NETWORK_TO_TORCH_MODEL[name](is_gpu=False)#.cuda() + else: + model = NETWORK_TO_TORCH_MODEL[name]() # .cuda() + model.eval() # Create the input data diff --git a/python/tvm/relay/transform/workloads/workloads.py b/python/tvm/relay/transform/workloads/workloads.py index 3dcdc3ea6f40..7aa5877e716c 100644 --- a/python/tvm/relay/transform/workloads/workloads.py +++ b/python/tvm/relay/transform/workloads/workloads.py @@ -1,8 +1,9 @@ # Value is shape dict WORKLOADS_DIC = { + "resnet_block" : {"input0": [1, 64, 56, 56]}, "resnet50" : {"input0": [1, 64, 56, 56]}, "resnext50_32x4d" : {"input0": [1, 64, 56, 56]}, - "nasneta" : {"input0": [1, 128, 56, 56]}, + "nasneta" : {"input0": [1, 64, 56, 56]}, "nasrnn": {'x.1': [1, 512]}, # "nasrnn": {'x.1': [1, 512], 'x.2': [1, 512], 'x.3': [1, 512], 'x.4': [1, 512], 'x': [1, 512]}, "bert": {"input0": [64, 1024]}, diff --git a/python/tvm/topi/cuda/bias_add.py b/python/tvm/topi/cuda/bias_add.py index b3ef3461385b..278236a19761 100644 --- a/python/tvm/topi/cuda/bias_add.py +++ b/python/tvm/topi/cuda/bias_add.py @@ -40,12 +40,12 @@ def schedule_bias_add(outs): return schedule_injective(outs) -def bias_add_cudnn(data, bias, axis=1): +def biasadd_cudnn(data, bias, axis=1): """Perform bias_add on the data using cudnn""" print("Python topi cuda cudnn bias_add!!") - return cudnn.bias_add(data, bias, axis) + return cudnn.biasadd(data, bias, axis) -def schedule_bias_add_cudnn(outs): +def schedule_biasadd_cudnn(outs): """Schedule for softmax cudnn op""" return generic.schedule_extern(outs) diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index 63c7c9308284..ce50afc64314 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -135,6 +135,79 @@ 
 )
 
 
+@autotvm.register_topi_compute("conv2d_biasadd_relu_cudnn.cuda")
+def conv2d_biasadd_relu_cudnn(
+    cfg, data, kernel, z, bias, strides, padding, dilation, groups=1, layout="NCHW", out_dtype="float32"
+):
+    """Compute fused conv2d + bias_add + relu using the CuDNN library"""
+    if layout == "NCHW":
+        tensor_format = 0  # CUDNN_TENSOR_NCHW
+        N, _, H, W = get_const_tuple(data.shape)
+    elif layout == "NHWC":
+        tensor_format = 1  # CUDNN_TENSOR_NHWC
+        N, H, W, _ = get_const_tuple(data.shape)
+    else:
+        raise ValueError("Unsupported layout %s in cudnn" % layout)
+    CO, CI, KH, KW = get_const_tuple(kernel.shape)
+
+    # handle dilation
+    stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
+    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
+
+    if (
+        isinstance(padding, (list, tuple))
+        and len(padding) == 4
+        and (padding[0] != padding[2] or padding[1] != padding[3])
+    ):
+        raise ValueError("Cudnn doesn't support asymmetric padding.")
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+    OH = (H + pt + pb - KH) // stride_h + 1
+    OW = (W + pl + pr - KW) // stride_w + 1
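+    # Quick check of the output-size rule above (hypothetical numbers): with
+    # H = W = 224, KH = KW = 3, symmetric padding (1, 1) and stride 1,
+    # OH = (224 + 1 + 1 - 3) // 1 + 1 = 224, so this fused conv preserves the
+    # spatial size used by the conv2d+bias+relu configs in test_ops.py.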
+    if isinstance(N, int):
+        cfg.add_flop(
+            groups
+            * 2
+            * N
+            * OH
+            * OW
+            * CO
+            * CI
+            * ((KH - 1) * dilation_h + 1)
+            * ((KW - 1) * dilation_w + 1)
+        )
+
+    if data.dtype == "int8" or kernel.dtype == "int8":
+        if layout == "NCHW":
+            raise ValueError("NCHW layout does not support int8 in cudnn")
+        dtype = "int32"
+    else:
+        dtype = data.dtype
+
+    cfg.define_knob("algo", range(8))
+    if cfg.is_fallback:  # Let CUDNN choose the best algo
+        cfg["algo"] = OtherOptionEntity(-1)
+
+    return cudnn.conv2d_biasadd_relu(
+        data,
+        kernel,
+        z,
+        bias,
+        [pt, pl],  # cudnn padding pt, pl on both sides of input
+        [stride_h, stride_w],
+        [dilation_h, dilation_w],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        algo=cfg["algo"].val,
+        conv_dtype=dtype,
+        activ_mode=1,
+        nan_prop_mode=0,
+        actv_coeff=1e100,
+        groups=groups,
+    )
+
+
 @autotvm.register_topi_schedule("conv2d_cudnn.cuda")
 def schedule_conv2d_cudnn(cfg, outs):
     """Create the schedule for conv2d_cudnn"""
diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index f2a6aadb659f..b635174bbe90 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -20,6 +20,7 @@
 from tvm import te
 from .. import tag
 from ..utils import traverse_inline
+from tvm.contrib import cudnn
 
 
 def schedule_adaptive_pool(outs, layout="NCHW"):
@@ -201,3 +202,9 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)
 
     return s
+
+
+def maxpool2d_cudnn(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad):
+    """Perform max pool2d on the data using cudnn"""
+    return cudnn.maxpool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad)
+
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 0777b19ec557..0a576933d2b3 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -45,6 +45,7 @@
 #include
 #include
+#include
 
 #include "../transforms/pass_utils.h"
 #include "utils.h"
@@ -104,15 +105,21 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
   }
 
+  tvm::runtime::Optional<String> dp_info;
+
   CachedFunc Create(const Function& prim_func) {
     auto cache_node = make_object<CachedFuncNode>();
     cache_node->target = target_;
+
+    Array<te::Tensor> all_inputs;
     for (Var param : prim_func->params) {
       Array<te::Tensor> inputs;
       if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
         cache_node->inputs.push_back(tensor);
         inputs.push_back(tensor);
+        all_inputs.push_back(tensor);
       } else {
         // flatten tuple of tensor type.
         const auto* tuple_type = param->type_as<TupleTypeNode>();
@@ -123,12 +130,36 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
         cache_node->inputs.push_back(tensor);
         inputs.push_back(tensor);
+        all_inputs.push_back(tensor);
       }
     }
     memo_[param] = inputs;
   }
   readable_name_stream_ << "fused";
-    cache_node->outputs = this->VisitExpr(prim_func->body);
+
+    dp_info = prim_func->GetAttr<String>(attr::kBackendOp);
+    std::string dp_target = "";
+    if (dp_info != nullptr) dp_target = std::string(dp_info.value());
+
+    // NOTE: update target --> Target("llvm")
+    std::cerr << "DP_TARGET: " << dp_target << "\n";
+    bool doCustomLowering = dp_target.size() > 0
+        && (dp_target.find("INVALID_BACKEND_OP") == std::string::npos)
+        && (dp_target.find("tvmgpu") == std::string::npos);
+
+    if (doCustomLowering) {
+      // Note: Sung
+      cache_node->outputs = myVisitExpr(prim_func, dp_target);
+
+      assert(0);
+
+    } else {
+      std::cerr << ">> Regular Path\n";
+      std::cerr << prim_func << "\n";
+      cache_node->outputs = this->VisitExpr(prim_func->body);
+    }
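+    // Sketch of what dp_target looks like (inferred from the Python side,
+    // which annotates each op as f"{group}-{backend_op_name}"): e.g.
+    // "0-cudnn_conv2d+biasadd+relu" for a cuDNN-mapped group, or the dummy
+    // "9999999-PYTHON_INVALID_BACKEND_OP"; that is why the substring tests
+    // above suffice to route between the custom and regular lowering paths.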
+
+
    auto candidate_name = readable_name_stream_.str();
    constexpr static size_t kMaxFuncNameLength = 80;
    if (candidate_name.size() > kMaxFuncNameLength) {
@@ -164,17 +195,21 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
    }

    // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule.
-    if (!schedule.defined()) {
-      ICHECK(anchor_implementation_.defined());
-      schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
-    }
+    //if(dp_target.size()==0){
+    if (!schedule.defined()) {
+      ICHECK(anchor_implementation_.defined());
+      schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+    }
    for (const auto& scalar : scalars_) {
      if (schedule->Contain(scalar)) {
        schedule[scalar].compute_inline();
      }
    }
+    //}
  }
  cache_node->schedule = std::move(schedule);
+
+
  return CachedFunc(cache_node);
}
@@ -211,11 +246,101 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
    return {value};
  }

-  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+
+  void collectInputs(Map<Expr, Array<te::Tensor>>& iMap, const Expr& expr) {
+    const CallNode* call_node = static_cast<const CallNode*>(expr.get());
+    ICHECK(call_node);
+
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+
+      if (arg.as<CallNode>()) collectInputs(iMap, arg);
+
+      Array<te::Tensor> inputs;
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+      iMap.Set(arg, inputs);
+    }
+
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
+    }
+  }
+
+  Array<te::Tensor> myVisitExpr(const Function& prim_func, const std::string dp_target) {
+    const CallNode* call_node = static_cast<const CallNode*>(prim_func->body.get());
+
    using tir::make_const;
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Map<Expr, Array<te::Tensor>> inputMap;
+    collectInputs(inputMap, prim_func->body);
+
+    //for(auto arg:inputMap)
+    //  for(auto v:arg.second)
+    //    std::cerr << arg.first << " ==> " << v << "\n";
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    Array<te::Tensor> outputs;
+    OpImplementation impl;
+
+    std::cerr << ">> Sung's intercept\n";
+    static auto ftarget_specific_lower_call = tvm::runtime::Registry::Get("relay.backend.target_specific_lowering");
+    LoweredOutput lowered_out = (*ftarget_specific_lower_call)(prim_func, inputMap, dp_target);
+    //LoweredOutput lowered_out = (*ftarget_specific_lower_call)(prim_func, inputs);
+    outputs = lowered_out->outputs;
+    impl = lowered_out->implementation;
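+
+    // Contract of the Python hook used above (a sketch): the function
+    // registered as "relay.backend.target_specific_lowering" receives the
+    // primitive function, the Expr -> te::Tensor input map, and the backend
+    // annotation string, and hands back a LoweredOutput whose `outputs` and
+    // `implementation` are consumed exactly like the regular lower_call result.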
+
+    int op_pattern = fpattern[op];
+    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+          << " anchor=" << anchor_op_ << " current=" << op;
+    }
+    if (op_pattern >= anchor_op_pattern_) {
+      anchor_op_ = op;
+      anchor_attrs_ = call_node->attrs;
+      anchor_op_pattern_ = op_pattern;
+      anchor_implementation_ = impl;
+    }
+
+    if (outputs.size() != 1) {
+      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type) << "Expect output to be a tuple type";
+      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+    }
+
+    // Set the name to `__copy`. It will be detected in graph executor to perform
+    // data copy across devices.
+    if (op == device_copy_op_) {
+      readable_name_stream_.str(std::string());
+      readable_name_stream_ << "__copy";
+    } else {
+      readable_name_stream_ << '_' << op->name;
+    }
+    return outputs;
+  }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+
+    using tir::make_const;
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (Expr arg : call_node->args) {
@@ -240,7 +365,12 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
      const auto* copy_input = inputs[0].operator->();
      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
    } else {
+      std::string dp_target = "";
+      if (dp_info != nullptr)
+        dp_target = std::string(dp_info.value());
+
      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+
      outputs = lowered_out->outputs;
      impl = lowered_out->implementation;
    }
@@ -274,11 +404,13 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+    using tir::make_const;
    LOG(FATAL) << "Do not support sub function";
    return Array<te::Tensor>();
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
+    using tir::make_const;
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
@@ -286,6 +418,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
+    using tir::make_const;
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
@@ -297,6 +430,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+    using tir::make_const;
    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
    Array<te::Tensor> tuple = VisitExpr(op->tuple);
    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
@@ -374,7 +508,9 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
    }
    readable_name_stream_ << "shape_func";
    auto cache_node = make_object<CachedFuncNode>();
+
    cache_node->outputs = VisitExpr(prim_func->body);
+
    auto candidate_name = readable_name_stream_.str();
    constexpr static size_t kMaxFuncNameLength = 80;
    if (candidate_name.size() > kMaxFuncNameLength) {
@@ -747,6 +883,8 @@ class CompileEngineImpl : public CompileEngineNode {
    auto cfunc = CreateSchedule(key->source_func, key->target);
    auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));

+    //std::cerr << "@@@ Schedule is created\n";
+
    // Skip lowering for device copy node.
    const Expr body = (key->source_func)->body;
    if (const CallNode* call_node = body.as<CallNode>()) {
@@ -762,6 +900,8 @@ class CompileEngineImpl : public CompileEngineNode {
    for (te::Tensor arg : cache_node->outputs) {
      all_args.push_back(arg);
    }
+
+    //std::cerr << "@@@ Lower the function\n";
    // lower the function
    if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
      cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
diff --git a/src/runtime/contrib/cudnn/add.cc b/src/runtime/contrib/cudnn/add.cc
index 75650f79fb6d..7e20744f9df5 100644
--- a/src/runtime/contrib/cudnn/add.cc
+++ b/src/runtime/contrib/cudnn/add.cc
@@ -40,6 +40,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.add")
      int axis = args[4];
      int ndim = x->ndim;

+      ICHECK_EQ(ndim, 2);
+
      int64_t* shape = x->shape;
      if (axis < 0) axis += ndim;
      ICHECK(axis >= 0 && axis < ndim);
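+      // For reference, the cuDNN call wrapped here is cudnnAddTensor, whose
+      // semantics are C = alpha * A + beta * C (computed in place on C); that
+      // is why every "tvm.contrib.cudnn.add" call site passes an alpha and a
+      // beta scalar right after the two tensor arguments.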