[Feature] Add cuda ops: UpFirDn2d and fused_bias_leakyrelu (#900)
* add upfirdn2d op

* fix bug in pybind

* add fused bias leakyrelu

* fix bug in fused-bias-leakyrelu

* fix lint error

* fix bug in build cpu version

* fix bug in build cpu version

* fix lint

* fix comment from zww

Co-authored-by: zhangshilong <zhangshilong@sensetime.com>
nbei and zhangshilong committed Mar 21, 2021
1 parent 371a217 commit 933b052
Showing 10 changed files with 983 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mmcv/ops/__init__.py
@@ -12,6 +12,7 @@
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
                         sigmoid_focal_loss, softmax_focal_loss)
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version,
                   get_onnxruntime_op_path)
from .masked_conv import MaskedConv2d, masked_conv2d
@@ -27,6 +28,7 @@
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d

__all__ = [
    'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
@@ -41,5 +43,6 @@
    'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
    'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
    'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
    'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated'
    'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
    'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu'
]
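
With these exports in place, the new ops become importable from mmcv.ops. Below is a minimal usage sketch; the Python-level signatures (the channel-count constructor argument, negative_slope/scale defaults, and the up/down/pad keywords) are assumptions based on the StyleGAN2 reference code, since the wrapper modules mmcv/ops/fused_bias_leakyrelu.py and mmcv/ops/upfirdn2d.py are not shown in this excerpt.

import torch
from mmcv.ops import FusedBiasLeakyReLU, fused_bias_leakyrelu, upfirdn2d

x = torch.randn(2, 64, 32, 32, device='cuda')

# Functional form: the bias is broadcast over the channel dimension.
bias = torch.zeros(64, device='cuda')
y = fused_bias_leakyrelu(x, bias)  # assumed defaults: slope 0.2, scale sqrt(2)

# Module form wrapping the same op with a learnable per-channel bias.
act = FusedBiasLeakyReLU(64).cuda()  # assumed constructor: number of channels
y = act(x)

# upfirdn2d: upsample, FIR-filter, downsample in a single fused CUDA kernel.
kernel = torch.ones(3, 3, device='cuda') / 9.0
z = upfirdn2d(x, kernel, up=2, down=1, pad=(1, 1))  # assumed keyword names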
26 changes: 26 additions & 0 deletions mmcv/ops/csrc/pytorch/fused_bias_leakyrelu.cpp
@@ -0,0 +1,26 @@
// Modified from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
                                      const torch::Tensor& bias,
                                      const torch::Tensor& refer, int act,
                                      int grad, float alpha, float scale);

#endif

torch::Tensor fused_bias_leakyrelu(const torch::Tensor& input,
                                   const torch::Tensor& bias,
                                   const torch::Tensor& refer, int act,
                                   int grad, float alpha, float scale) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  return fused_bias_leakyrelu_op(input, bias, refer, act, grad, alpha, scale);
#else
  AT_ERROR("Fused bias leakyrelu is not compiled with GPU support");
#endif
}
109 changes: 109 additions & 0 deletions mmcv/ops/csrc/pytorch/fused_bias_leakyrelu_cuda.cu
@@ -0,0 +1,109 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>

template <typename scalar_t>
static __global__ void fused_bias_act_kernel(
    scalar_t* out, const scalar_t* p_x, const scalar_t* p_b,
    const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
    int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    // act = 1: linear layer
    // act = 3: leaky relu layer
    // grad = 0: direct forward path
    // grad = 1: first-order derivative
    // grad = 2: second-order derivative
    switch (act * 10 + grad) {
      default:
      case 10:
        y = x;
        break;
      case 11:
        y = x;
        break;
      case 12:
        y = 0.0;
        break;

      case 30:
        y = (x > 0.0) ? x : x * alpha;
        break;
      case 31:
        y = (ref > 0.0) ? x : x * alpha;
        break;
      case 32:
        y = 0.0;
        break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
                                      const torch::Tensor& bias,
                                      const torch::Tensor& refer, int act,
                                      int grad, float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  // The bias is broadcast over dimension 1 (channels), so its stride in the
  // flattened tensor is the product of all dimensions after dim 1.
  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
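
For readers comparing against the kernel above: the forward path (act = 3, grad = 0) amounts to a bias-add broadcast along dim 1, a leaky ReLU, and a rescale. A minimal PyTorch sketch of that equivalence follows; the default slope 0.2 and scale sqrt(2) are assumptions taken from the StyleGAN2 reference code, not values fixed by this commit.

import torch
import torch.nn.functional as F

def fused_bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Reference for the act = 3, grad = 0 branch of fused_bias_act_kernel:
    # broadcast the bias over dim 1, apply leaky ReLU, then rescale.
    shape = [1, -1] + [1] * (x.dim() - 2)
    return F.leaky_relu(x + bias.view(shape), negative_slope=negative_slope) * scale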
11 changes: 11 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
@@ -182,7 +182,18 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
                   const Tensor dets_sorted, const float iou_threshold,
                   const int multi_label);

Tensor upfirdn2d(const Tensor& input, const Tensor& kernel, int up_x, int up_y,
                 int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
                 int pad_y1);

Tensor fused_bias_leakyrelu(const Tensor& input, const Tensor& bias,
                            const Tensor& refer, int act, int grad, float alpha,
                            float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
  m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
        "fused_bias_leakyrelu (CUDA)");
  m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
  m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
        "get_compiling_cuda_version");
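
These bindings are what the Python wrappers ultimately call into. A rough sketch of how the bound symbols are typically reached from Python in mmcv is shown below; the wrapper code itself is not part of this excerpt, so the argument values are illustrative only.

import torch
from mmcv.utils import ext_loader

# Load the compiled extension and look up the two symbols registered above.
ext_module = ext_loader.load_ext('_ext', ['upfirdn2d', 'fused_bias_leakyrelu'])

x = torch.randn(2, 64, 32, 32, device='cuda')
bias = torch.zeros(64, device='cuda')
empty = x.new_empty(0)  # unused refer tensor in the forward path

# Arguments follow the C++ signature: (input, bias, refer, act, grad, alpha, scale).
# act=3 selects leaky ReLU, grad=0 the forward path; 0.2 and sqrt(2) are the
# StyleGAN2 defaults (illustrative values, not mandated by this commit).
out = ext_module.fused_bias_leakyrelu(x, bias, empty, 3, 0, 0.2, 2 ** 0.5)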
25 changes: 25 additions & 0 deletions mmcv/ops/csrc/pytorch/upfirdn2d.cpp
@@ -0,0 +1,25 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
torch::Tensor upfirdn2d_op(const torch::Tensor& input,
                           const torch::Tensor& kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#endif

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
#else
  AT_ERROR("UpFirDn2d is not compiled with GPU support");
#endif
}
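
As context for what this op computes: UpFirDn2d upsamples by zero insertion, applies a 2D FIR filter, then downsamples, all in one pass. Below is a simplified, unoptimized PyTorch sketch of those semantics; the symmetric pad tuple and the per-channel grouped convolution are simplifications of the CUDA kernel's interface, not its exact indexing.

import torch
import torch.nn.functional as F

def upfirdn2d_ref(x, kernel, up=1, down=1, pad=(0, 0)):
    # x: (N, C, H, W); kernel: (kh, kw) FIR filter shared by all channels.
    n, c, h, w = x.shape

    # 1. Upsample by inserting zeros between input samples.
    out = x.new_zeros(n, c, h * up, w * up)
    out[:, :, ::up, ::up] = x

    # 2. Pad, then convolve each channel with the (flipped) FIR kernel.
    out = F.pad(out, [pad[0], pad[1], pad[0], pad[1]])
    kh, kw = kernel.shape
    weight = torch.flip(kernel, [0, 1]).view(1, 1, kh, kw).repeat(c, 1, 1, 1)
    out = F.conv2d(out, weight.to(out), groups=c)

    # 3. Downsample by keeping every `down`-th sample.
    return out[:, :, ::down, ::down]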
