Commit

Merge pull request apache#2 from cmu-catalyst/main
CuDNN op impl to DP fused pass
MadFunMaker committed Apr 8, 2021
2 parents 8d343ec + 5e9242c commit 15ddf4b
Showing 14 changed files with 1,359 additions and 5 deletions.
82 changes: 82 additions & 0 deletions python/tvm/contrib/cudnn.py
@@ -469,3 +469,85 @@ def softmax(x, axis=-1):
        ),
        name="y",
    )

def max_pool2d(x, pool_size, strides, padding, pool_type, ceil_mode, data_layout, count_include_pad):
    """Compute 2-D max pooling using cuDNN.

    Parameters
    ----------
    x : tvm.te.Tensor
        The input tensor (NCHW layout)
    pool_size : tuple of int
        The pooling window (height, width)
    strides : tuple of int
        The stride (vertical, horizontal)
    padding : tuple of int
        The padding (vertical, horizontal)
    pool_type : str
        The pooling type; only "max" is handled here
    ceil_mode : bool
        Whether to use ceil when computing the output shape
    data_layout : str
        The data layout; only "NCHW" is handled here
    count_include_pad : bool
        Whether padded elements are counted (not forwarded to cuDNN)

    Returns
    -------
    ret : tvm.te.Tensor
        The result tensor
    """
    # Packed-argument layout expected by tvm.contrib.cudnn.pooling.forward:
    #   args[2] = alpha (double), args[3] = beta (double),
    #   args[4] = mode, args[5] = nanOpt,
    #   args[6] = windowHeight, args[7] = windowWidth,
    #   args[8] = verticalPadding, args[9] = horizontalPadding,
    #   args[10] = verticalStride, args[11] = horizontalStride
    print("Python cudnn.py pool2d!!", file=sys.stderr)
    # NOTE: the output shape is assumed equal to the input shape.
    return te.extern(
        x.shape,
        [x],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.pooling.forward", ins[0], outs[0],
            # alpha, beta, mode (0 = CUDNN_POOLING_MAX), nanOpt
            1.0, 0.0, 0, 0,
            pool_size[0], pool_size[1], padding[0], padding[1],
            strides[0], strides[1]
        ),
        name="y",
    )
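
A minimal usage sketch of the wrapper above (hedged: it assumes this fork is built with cuDNN support and that "tvm.contrib.cudnn.pooling.forward" is registered elsewhere in this commit; shapes and the build call are illustrative only):

import tvm
from tvm import te
from tvm.contrib import cudnn

x = te.placeholder((1, 64, 56, 56), name="x", dtype="float32")
# Only window, padding, and stride reach the packed call; the remaining
# parameters are accepted but not forwarded.
y = cudnn.max_pool2d(x, pool_size=(3, 3), strides=(1, 1), padding=(1, 1),
                     pool_type="max", ceil_mode=False,
                     data_layout="NCHW", count_include_pad=True)
s = te.create_schedule(y.op)
mod = tvm.build(s, [x, y], target="cuda", name="pool")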


def relu(x):
    """Compute ReLU using cuDNN."""
    print("Python cudnn.py relu!!", file=sys.stderr)
    return te.extern(
        x.shape,
        [x],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.activation.forward", ins[0], outs[0],
            # alpha, beta, mode (1 = CUDNN_ACTIVATION_RELU), nanOpt, coeff
            1, 0, 1, 0, 20
        ),
        name="y",
    )

def bias_add(data, bias, axis):
    """Add bias to data using cuDNN."""
    print("Python cudnn.py bias_add!!", file=sys.stderr)
    return te.extern(
        data.shape,
        [bias, data],
        lambda ins, outs: tvm.tir.call_packed(
            # args are x, y, alpha, beta, axis; the C++ side registered in
            # src/runtime/contrib/cudnn/add.cc reads axis as args[4].
            "tvm.contrib.cudnn.add", ins[0], outs[0],
            1, 1, axis
        ),
        name="y",
    )



def conv_bias_activation_forward(x, w, oshape):
    """Unfinished stub for the fused conv2d + bias + activation kernel."""
    return te.extern(
        oshape,
        [x, w],
        lambda ins, outs: tvm.tir.call_packed(
            # TODO: forward the tensors and cuDNN parameters once the fused
            # kernel's packed-argument layout is settled.
            "tvm.contrib.cudnn.conv2d+bias+activation.forward",
        ),
        name="y",
    )
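
The stub above leaves the packed-call argument list empty. Below is a hypothetical completion sketch, modeled on the parameter list of cuDNN's cudnnConvolutionBiasActivationForward; this diff does not show the C++ registration for "tvm.contrib.cudnn.conv2d+bias+activation.forward", so every argument position here is an assumption, not this fork's actual layout:

# Hypothetical sketch only; the packed-argument layout below is assumed.
def conv_bias_activation_forward_sketch(x, w, bias, oshape, pad, stride, dilation):
    return te.extern(
        oshape,
        [x, w, bias],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.conv2d+bias+activation.forward",
            ins[0], ins[1], ins[2], outs[0],
            # alpha1, alpha2, conv mode, activation mode (assumed ReLU = 1)
            1.0, 0.0, 1, 1,
            pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1],
        ),
        name="y",
    )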

6 changes: 3 additions & 3 deletions python/tvm/relay/backend/_backend.py
@@ -44,11 +44,11 @@ def lower(sch, inputs, func_name, source_func):
"""
# pylint: disable=broad-except, import-outside-toplevel
import traceback

import logging
try:
f = tvm.driver.lower(sch, inputs, name=func_name)
# logging.debug("lower function %s", func_name)
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
#logging.debug("lower function %s", func_name)
#logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile function\n"
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/__init__.py
@@ -58,3 +58,5 @@
 from .scan import *
 from .sparse_reshape import *
 from .unique import *
+from .activation import *
+from .bias_add import *
51 changes: 51 additions & 0 deletions python/tvm/topi/cuda/activation.py
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
from tvm.target import Target
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective


def schedule_relu(outs):
    """Schedule for relu op.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of relu in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return schedule_injective(outs)


def relu_cudnn(x):
    """Perform relu on the data using cudnn"""
    print("Python topi cuda cudnn relu!!")
    return cudnn.relu(x)


def schedule_relu_cudnn(outs):
    """Schedule for relu cudnn op"""
    return generic.schedule_extern(outs)
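
A short sketch of how this op/schedule pair composes (hedged; it relies on the __init__.py re-exports added in this commit):

from tvm import te
from tvm.topi.cuda import relu_cudnn, schedule_relu_cudnn

x = te.placeholder((1, 8, 4, 4), "float32", name="x")
y = relu_cudnn(x)              # te.extern node calling the cuDNN packed func
s = schedule_relu_cudnn([y])   # falls back to the generic extern schedule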
51 changes: 51 additions & 0 deletions python/tvm/topi/cuda/bias_add.py
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
from tvm.target import Target
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective


def schedule_bias_add(outs):
    """Schedule for bias_add op.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of bias_add in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return schedule_injective(outs)


def bias_add_cudnn(data, bias, axis=1):
    """Perform bias_add on the data using cudnn"""
    print("Python topi cuda cudnn bias_add!!")
    return cudnn.bias_add(data, bias, axis)


def schedule_bias_add_cudnn(outs):
    """Schedule for bias_add cudnn op"""
    return generic.schedule_extern(outs)
78 changes: 78 additions & 0 deletions src/runtime/contrib/cudnn/activation.cc
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
 * \file src/runtime/contrib/cudnn/activation.cc
 * \brief Use external cudnn activation function
 */
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.activation.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* x = args[0];
      DLTensor* y = args[1];
      double double_alpha = args[2];
      double double_beta = args[3];
      int mode = args[4];
      int nanOpt = args[5];
      double coeff = args[6];
      const void* alpha;
      const void* beta;

      CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
      entry_ptr->activation_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

      alpha = CuDNNDataType::GetConst(entry_ptr->activation_entry.data_type, double_alpha);
      beta = CuDNNDataType::GetConst(entry_ptr->activation_entry.data_type, double_beta);

      // Configure the activation descriptor (mode, NaN propagation, coefficient).
      CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->activation_entry.activation_desc,
                                              static_cast<cudnnActivationMode_t>(mode),
                                              static_cast<cudnnNanPropagation_t>(nanOpt), coeff));

      // Input and output share one 4-D NCHW shape descriptor.
      CUDNN_CALL(cudnnSetTensor4dDescriptor(
          entry_ptr->activation_entry.shape_desc, CUDNN_TENSOR_NCHW,
          entry_ptr->activation_entry.data_type, static_cast<int>(x->shape[0]),
          static_cast<int>(x->shape[1]), static_cast<int>(x->shape[2]),
          static_cast<int>(x->shape[3])));

      CUDNN_CALL(cudnnActivationForward(entry_ptr->handle,
                                        entry_ptr->activation_entry.activation_desc, alpha,
                                        entry_ptr->activation_entry.shape_desc, x->data, beta,
                                        entry_ptr->activation_entry.shape_desc, y->data));
    });

} // namespace contrib
} // namespace tvm
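
A hedged smoke-test sketch that exercises the packed function registered above directly (assumes a cuDNN-enabled build of this fork and the tvm.gpu/asnumpy API of this era):

import numpy as np
import tvm

f = tvm.get_global_func("tvm.contrib.cudnn.activation.forward")
ctx = tvm.gpu(0)
x = tvm.nd.array(np.random.randn(1, 8, 4, 4).astype("float32"), ctx)
y = tvm.nd.array(np.zeros((1, 8, 4, 4), dtype="float32"), ctx)
# x, y, alpha, beta, mode (1 = CUDNN_ACTIVATION_RELU), nanOpt, coeff
f(x, y, 1.0, 0.0, 1, 0, 0.0)
np.testing.assert_allclose(y.asnumpy(), np.maximum(x.asnumpy(), 0.0), rtol=1e-5)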
103 changes: 103 additions & 0 deletions src/runtime/contrib/cudnn/add.cc
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
 * \file src/runtime/contrib/cudnn/add.cc
 * \brief Use external cudnn add tensor function
 */
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.add")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* x = args[0];
      DLTensor* y = args[1];
      double double_alpha = args[2];
      double double_beta = args[3];
      int axis = args[4];

      int ndim = x->ndim;
      int64_t* shape = x->shape;
      if (axis < 0) axis += ndim;
      ICHECK(axis >= 0 && axis < ndim);
      const void* alpha;
      const void* beta;

      CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
      entry_ptr->bias_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

      alpha = CuDNNDataType::GetConst(entry_ptr->bias_entry.data_type, double_alpha);
      beta = CuDNNDataType::GetConst(entry_ptr->bias_entry.data_type, double_beta);

      // Fold the tensor into a 4-D NCHW view with the add axis mapped to C.
      if (axis == ndim - 1) {
        // Last axis: collapse all leading dimensions into N.
        int64_t N = 1;
        for (int i = 0; i < ndim - 1; ++i) {
          N *= shape[i];
        }
        CUDNN_CALL(cudnnSetTensor4dDescriptor(
            entry_ptr->bias_entry.shape_desc, CUDNN_TENSOR_NCHW, entry_ptr->bias_entry.data_type,
            static_cast<int>(N), static_cast<int>(shape[ndim - 1]), 1, 1));
      } else {
        // Interior axis: dimensions before it become N, dimensions after it become H.
        int64_t pre_axis_dim = 1;
        int64_t post_axis_dim = 1;
        for (int i = 0; i < ndim; ++i) {
          if (i < axis) {
            pre_axis_dim *= shape[i];
          } else if (i > axis) {
            post_axis_dim *= shape[i];
          }
        }
        CUDNN_CALL(cudnnSetTensor4dDescriptor(
            entry_ptr->bias_entry.shape_desc, CUDNN_TENSOR_NCHW, entry_ptr->bias_entry.data_type,
            static_cast<int>(pre_axis_dim), static_cast<int>(shape[axis]),
            static_cast<int>(post_axis_dim), 1));
      }

      // cudnnAddTensor computes y = alpha * x + beta * y in place.
      CUDNN_CALL(cudnnAddTensor(entry_ptr->handle, alpha, entry_ptr->bias_entry.shape_desc,
                                x->data, beta, entry_ptr->bias_entry.shape_desc, y->data));
    });

} // namespace contrib
} // namespace tvm
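
The same direct-call pattern works for the add kernel. Because cudnnAddTensor accumulates in place (y = alpha * x + beta * y) and both descriptors above are built from x's shape, y must already hold the data being biased and match x's shape (same build assumptions as the previous sketch):

import numpy as np
import tvm

f = tvm.get_global_func("tvm.contrib.cudnn.add")
ctx = tvm.gpu(0)
data = np.random.randn(2, 8, 4, 4).astype("float32")
bias = np.random.randn(2, 8, 4, 4).astype("float32")
x = tvm.nd.array(bias, ctx)
y = tvm.nd.array(data, ctx)  # read and overwritten in place
# x, y, alpha, beta, axis; with alpha = beta = 1 this computes y += x
f(x, y, 1.0, 1.0, 1)
np.testing.assert_allclose(y.asnumpy(), data + bias, rtol=1e-5)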