diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index b363156113..5af7fe6f27 100644
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -12,6 +12,7 @@ from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
 from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
                          sigmoid_focal_loss, softmax_focal_loss)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
 from .info import (get_compiler_version, get_compiling_cuda_version,
                    get_onnxruntime_op_path)
 from .masked_conv import MaskedConv2d, masked_conv2d
@@ -27,6 +28,7 @@ from .saconv import SAConv2d
 from .sync_bn import SyncBatchNorm
 from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
 
 __all__ = [
     'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
@@ -41,5 +43,6 @@
     'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
     'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
     'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
-    'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated'
+    'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
+    'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu'
 ]
diff --git a/mmcv/ops/csrc/pytorch/fused_bias_leakyrelu.cpp b/mmcv/ops/csrc/pytorch/fused_bias_leakyrelu.cpp
new file mode 100644
index 0000000000..45637ad4f8
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/fused_bias_leakyrelu.cpp
@@ -0,0 +1,26 @@
+// Modified from
+// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
+                                      const torch::Tensor& bias,
+                                      const torch::Tensor& refer, int act,
+                                      int grad, float alpha, float scale);
+
+#endif
+
+torch::Tensor fused_bias_leakyrelu(const torch::Tensor& input,
+                                   const torch::Tensor& bias,
+                                   const torch::Tensor& refer, int act,
+                                   int grad, float alpha, float scale) {
+#ifdef MMCV_WITH_CUDA
+  CHECK_CUDA(input);
+  CHECK_CUDA(bias);
+
+  return fused_bias_leakyrelu_op(input, bias, refer, act, grad, alpha, scale);
+#else
+  AT_ERROR("Fused bias leakyrelu is not compiled with GPU support");
+#endif
+}
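Reviewer note: before the CUDA kernel, a minimal pure-PyTorch sketch of what the fused path computes in the forward case (`act = 3`, `grad = 0`): add the per-channel bias, apply leaky ReLU, then rescale. The helper name is ours, not part of this PR; the first-order case (`grad = 1`) applies the same slope selection to the incoming gradient, keying off the sign of a saved reference tensor instead of the input.

```python
import torch
import torch.nn.functional as F


def fused_bias_leakyrelu_sketch(x, bias, negative_slope=0.2, scale=2**0.5):
    # Broadcast the per-channel bias over (N, C, ...), matching the
    # step_b / size_b indexing in the CUDA kernel below.
    shape = [1, -1] + [1] * (x.ndim - 2)
    return F.leaky_relu(x + bias.reshape(shape), negative_slope) * scale
```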
diff --git a/mmcv/ops/csrc/pytorch/fused_bias_leakyrelu_cuda.cu b/mmcv/ops/csrc/pytorch/fused_bias_leakyrelu_cuda.cu
new file mode 100644
index 0000000000..cde947de48
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/fused_bias_leakyrelu_cuda.cu
@@ -0,0 +1,109 @@
+// from
+// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(
+    scalar_t* out, const scalar_t* p_x, const scalar_t* p_b,
+    const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
+    int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+  scalar_t zero = 0.0;
+
+  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
+       loop_idx++, xi += blockDim.x) {
+    scalar_t x = p_x[xi];
+
+    if (use_bias) {
+      x += p_b[(xi / step_b) % size_b];
+    }
+
+    scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+    scalar_t y;
+
+    // act = 1: linear layer
+    // act = 3: leaky relu layer
+    // grad = 0: direct forward path
+    // grad = 1: first order derivative
+    // grad = 2: second order derivative
+    switch (act * 10 + grad) {
+      default:
+      case 10:
+        y = x;
+        break;
+      case 11:
+        y = x;
+        break;
+      case 12:
+        y = 0.0;
+        break;
+
+      case 30:
+        y = (x > 0.0) ? x : x * alpha;
+        break;
+      case 31:
+        y = (ref > 0.0) ? x : x * alpha;
+        break;
+      case 32:
+        y = 0.0;
+        break;
+    }
+
+    out[xi] = y * scale;
+  }
+}
+
+torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
+                                      const torch::Tensor& bias,
+                                      const torch::Tensor& refer, int act,
+                                      int grad, float alpha, float scale) {
+  int curDevice = -1;
+  cudaGetDevice(&curDevice);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+  auto x = input.contiguous();
+  auto b = bias.contiguous();
+  auto ref = refer.contiguous();
+
+  int use_bias = b.numel() ? 1 : 0;
+  int use_ref = ref.numel() ? 1 : 0;
+
+  int size_x = x.numel();
+  int size_b = b.numel();
+  int step_b = 1;
+
+  // The bias is broadcast over every dim after the channel dim.
+  for (int i = 2; i < x.dim(); i++) {
+    step_b *= x.size(i);
+  }
+
+  int loop_x = 4;
+  int block_size = 4 * 32;
+  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+  auto y = torch::empty_like(x);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
+            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
+      });
+
+  return y;
+}
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 70d5cebbdd..bb8d415fd5 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -182,7 +182,18 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
                    const Tensor dets_sorted, const float iou_threshold,
                    const int multi_label);
 
+Tensor upfirdn2d(const Tensor& input, const Tensor& kernel, int up_x, int up_y,
+                 int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
+                 int pad_y1);
+
+Tensor fused_bias_leakyrelu(const Tensor& input, const Tensor& bias,
+                            const Tensor& refer, int act, int grad, float alpha,
+                            float scale);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+  m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
+        "fused_bias_leakyrelu (CUDA)");
   m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
   m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
         "get_compiling_cuda_version");
diff --git a/mmcv/ops/csrc/pytorch/upfirdn2d.cpp b/mmcv/ops/csrc/pytorch/upfirdn2d.cpp
new file mode 100644
index 0000000000..44322a1210
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/upfirdn2d.cpp
@@ -0,0 +1,25 @@
+// from
+// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+torch::Tensor upfirdn2d_op(const torch::Tensor& input,
+                           const torch::Tensor& kernel, int up_x, int up_y,
+                           int down_x, int down_y, int pad_x0, int pad_x1,
+                           int pad_y0, int pad_y1);
+
+#endif
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
+                        int pad_x1, int pad_y0, int pad_y1) {
+#ifdef MMCV_WITH_CUDA
+  CHECK_CUDA(input);
+  CHECK_CUDA(kernel);
+
+  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+                      pad_y0, pad_y1);
+#else
+  AT_ERROR("UpFirDn2d is not compiled with GPU support");
+#endif
+}
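For orientation before the kernel file: both CUDA paths size the output as `out = (in * up + pad0 + pad1 - kernel + down) / down` in integer arithmetic, which equals `floor((in * up + pad0 + pad1 - kernel) / down) + 1`. A quick sanity check in Python (numbers chosen for illustration only):

```python
def upfirdn2d_out_size(in_size, up, down, pad0, pad1, kernel):
    # Mirrors the p.out_h / p.out_w computation in upfirdn2d_kernel.cu.
    return (in_size * up + pad0 + pad1 - kernel) // down + 1


# A 4x4 input, 2x upsampling, 4-tap kernel, pad (2, 1) -> 8x8 output.
assert upfirdn2d_out_size(4, up=2, down=1, pad0=2, pad1=1, kernel=4) == 8
```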
diff --git a/mmcv/ops/csrc/pytorch/upfirdn2d_kernel.cu b/mmcv/ops/csrc/pytorch/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000..52a175bfd2
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/upfirdn2d_kernel.cu
@@ -0,0 +1,370 @@
+// from
+// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+  int c = a / b;
+
+  if (c * b > a) {
+    c--;
+  }
+
+  return c;
+}
+
+struct UpFirDn2DKernelParams {
+  int up_x;
+  int up_y;
+  int down_x;
+  int down_y;
+  int pad_x0;
+  int pad_x1;
+  int pad_y0;
+  int pad_y1;
+
+  int major_dim;
+  int in_h;
+  int in_w;
+  int minor_dim;
+  int kernel_h;
+  int kernel_w;
+  int out_h;
+  int out_w;
+  int loop_major;
+  int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+                                       const scalar_t *kernel,
+                                       const UpFirDn2DKernelParams p) {
+  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int out_y = minor_idx / p.minor_dim;
+  minor_idx -= out_y * p.minor_dim;
+  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+  int major_idx_base = blockIdx.z * p.loop_major;
+
+  if (out_x_base >= p.out_w || out_y >= p.out_h ||
+      major_idx_base >= p.major_dim) {
+    return;
+  }
+
+  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+  for (int loop_major = 0, major_idx = major_idx_base;
+       loop_major < p.loop_major && major_idx < p.major_dim;
+       loop_major++, major_idx++) {
+    for (int loop_x = 0, out_x = out_x_base;
+         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+      const scalar_t *x_p =
+          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+                 minor_idx];
+      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+      int x_px = p.minor_dim;
+      int k_px = -p.up_x;
+      int x_py = p.in_w * p.minor_dim;
+      int k_py = -p.up_y * p.kernel_w;
+
+      scalar_t v = 0.0f;
+
+      for (int y = 0; y < h; y++) {
+        for (int x = 0; x < w; x++) {
+          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+          x_p += x_px;
+          k_p += k_px;
+        }
+
+        x_p += x_py - w * x_px;
+        k_p += k_py - w * k_px;
+      }
+
+      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+          minor_idx] = v;
+    }
+  }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+                                 const scalar_t *kernel,
+                                 const UpFirDn2DKernelParams p) {
+  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+  __shared__ volatile float sk[kernel_h][kernel_w];
+  __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+  int minor_idx = blockIdx.x;
+  int tile_out_y = minor_idx / p.minor_dim;
+  minor_idx -= tile_out_y * p.minor_dim;
+  tile_out_y *= tile_out_h;
+  int
tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = 
-1;
+  int tile_out_w = -1;
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 1;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 3 && p.kernel_w <= 3) {
+    mode = 2;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 3;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 4;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 5;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 6;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  dim3 block_size;
+  dim3 grid_size;
+
+  if (tile_out_h > 0 && tile_out_w > 0) {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 1;
+    block_size = dim3(32 * 8, 1, 1);
+    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  } else {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 4;
+    block_size = dim3(4, 32, 1);
+    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+    switch (mode) {
+      case 1:
+        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                   x.data_ptr<scalar_t>(),
+                                                   k.data_ptr<scalar_t>(), p);
+        break;
+
+      case 2:
+        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                   x.data_ptr<scalar_t>(),
+                                                   k.data_ptr<scalar_t>(), p);
+        break;
+
+      case 3:
+        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                   x.data_ptr<scalar_t>(),
+                                                   k.data_ptr<scalar_t>(), p);
+        break;
+
+      case 4:
+        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                   x.data_ptr<scalar_t>(),
+                                                   k.data_ptr<scalar_t>(), p);
+        break;
+
+      case 5:
+        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                   x.data_ptr<scalar_t>(),
+                                                   k.data_ptr<scalar_t>(), p);
+        break;
+
+      case 6:
+        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
+            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                   x.data_ptr<scalar_t>(),
+                                                   k.data_ptr<scalar_t>(), p);
+        break;
+
+      default:
+        upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+            k.data_ptr<scalar_t>(), p);
+    }
+  });
+
+  return out;
+}
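The dispatch above only specializes the tiled kernel for the six (up, down, kernel size) combinations StyleGAN2 actually uses; everything else falls through to `upfirdn2d_kernel_large`. Restated compactly for review (illustration only, not code in this PR):

```python
# (up_x, up_y, down_x, down_y, max_kernel_hw) -> (mode, tile_out_h, tile_out_w)
# Later entries win when several match, e.g. a 3x3 kernel selects mode 2.
UPFIRDN2D_TILED_MODES = {
    (1, 1, 1, 1, 4): (1, 16, 64),
    (1, 1, 1, 1, 3): (2, 16, 64),
    (2, 2, 1, 1, 4): (3, 16, 64),
    (2, 2, 1, 1, 2): (4, 16, 64),
    (1, 1, 2, 2, 4): (5, 8, 32),
    (1, 1, 2, 2, 2): (6, 8, 32),
}
```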
+ """ + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = ext_module.fused_bias_leakyrelu(grad_output, empty, out, + 3, 1, negative_slope, + scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + + # The second order deviation, in fact, contains two parts, while the + # the first part is zero. Thus, we direct consider the second part + # which is similar with the first order deviation in implementation. + gradgrad_out = ext_module.fused_bias_leakyrelu(gradgrad_input, + gradgrad_bias, out, 3, + 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedBiasLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = ext_module.fused_bias_leakyrelu(input, bias, empty, 3, 0, + negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedBiasLeakyReLU(nn.Module): + """Fused bias leaky ReLU. + + This function is introduced in the StyleGAN2: + http://arxiv.org/abs/1912.04958 + + The bias term comes from the convolution operation. In addition, to keep + the variance of the feature map or gradients unchanged, they also adopt a + scale similarly with Kaiming initalization. However, since the + :math:`1 + \alpha^2` : is too small, we can just ignore it. Therefore, the + final sacle is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 + your own scale. + + TODO: Implement the CPU version. + + Args: + channel (int): The channnel number of the feature map. + negative_slope (float, optional): Same as nn.LeakyRelu. + Defaults to 0.2. + scale (float, optional): A scalar to adjust the variance of the feature + map. Defaults to 2**0.5. + """ + + def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5): + super(FusedBiasLeakyReLU, self).__init__() + + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_bias_leakyrelu(input, self.bias, self.negative_slope, + self.scale) + + +def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5): + """Fused bias leaky ReLU function. + + This function is introduced in the StyleGAN2: + http://arxiv.org/abs/1912.04958 + + The bias term comes from the convolution operation. In addition, to keep + the variance of the feature map or gradients unchanged, they also adopt a + scale similarly with Kaiming initalization. However, since the + :math:`1 + \alpha^2` : is too small, we can just ignore it. Therefore, the + final sacle is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 + your own scale. + + Args: + input (torch.Tensor): Input feature map. + bias (nn.Parameter): The bias from convolution operation. + negative_slope (float, optional): Same as nn.LeakyRelu. + Defaults to 0.2. 
diff --git a/mmcv/ops/upfirdn2d.py b/mmcv/ops/upfirdn2d.py
new file mode 100644
index 0000000000..c627edffed
--- /dev/null
+++ b/mmcv/ops/upfirdn2d.py
@@ -0,0 +1,208 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from ..utils import ext_loader
+
+upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
+
+
+class UpFirDn2dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
+                in_size, out_size):
+
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_ext.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
+                                     in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        kernel, = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
+                                                ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_ext.upfirdn2d(
+            gradgrad_input,
+            kernel,
+            ctx.up_x,
+            ctx.up_y,
+            ctx.down_x,
+            ctx.down_y,
+            ctx.pad_x0,
+            ctx.pad_x1,
+            ctx.pad_y0,
+            ctx.pad_y1,
+        )
+        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
+                                         ctx.out_size[0], ctx.out_size[1])
+
+        return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        kernel_h, kernel_w = kernel.shape
+        batch, channel, in_h, in_w = input.shape
+        ctx.in_size = input.shape
+
+        input = input.reshape(-1, in_h, in_w, 1)
+
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        g_pad_x0 = kernel_w - pad_x0 - 1
+        g_pad_y0 = kernel_h - pad_y0 - 1
+        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x,
+                                      down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+        out = out.view(-1, channel, out_h, out_w)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, grad_kernel = ctx.saved_tensors
+
+        grad_input = UpFirDn2dBackward.apply(
+            grad_output,
+            kernel,
+            grad_kernel,
+            ctx.up,
+            ctx.down,
+            ctx.pad,
+            ctx.g_pad,
+            ctx.in_size,
+            ctx.out_size,
+        )
+
+        return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    """UpFIRDn for 2d features.
+
+    UpFIRDn is short for upsample, apply FIR filter and downsample. More
+    details can be found in:
+    https://www.mathworks.com/help/signal/ref/upfirdn.html
+
+    Args:
+        input (Tensor): Tensor with shape of (n, c, h, w).
+        kernel (Tensor): Filter kernel.
+        up (int, optional): Upsampling factor. Defaults to 1.
+        down (int, optional): Downsampling factor. Defaults to 1.
+        pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad).
+            Defaults to (0, 0).
+
+    Returns:
+        Tensor: Tensor after UpFIRDn.
+    """
+    if input.device.type == 'cpu':
+        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0],
+                               pad[1], pad[0], pad[1])
+    else:
+        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down),
+                              (pad[0], pad[1], pad[0], pad[1]))
+
+    return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+                     pad_y0, pad_y1):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(
+        out,
+        [0, 0,
+         max(pad_x0, 0),
+         max(pad_x1, 0),
+         max(pad_y0, 0),
+         max(pad_y1, 0)])
+    out = out[:,
+              max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
+              max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape(
+        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
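Typical StyleGAN2-style usage of the new op, mirroring the unit test below: build a separable binomial blur kernel, then upsample by 2 with the matching padding:

```python
import torch
from mmcv.ops import upfirdn2d

kernel_1d = torch.tensor([1., 3., 3., 1.])
kernel = kernel_1d[:, None] * kernel_1d[None, :]
kernel = kernel / kernel.sum()

factor = 2
pad = kernel.shape[0] - factor
x = torch.randn(1, 3, 16, 16)

# (1, 3, 16, 16) -> (1, 3, 32, 32); CPU input runs upfirdn2d_native,
# CUDA input runs the kernels above.
out = upfirdn2d(x, kernel, up=factor, down=1,
                pad=((pad + 1) // 2 + factor - 1, pad // 2))
assert out.shape == (1, 3, 32, 32)
```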
diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py
new file mode 100644
index 0000000000..ae237f357a
--- /dev/null
+++ b/tests/test_ops/test_fused_bias_leakyrelu.py
@@ -0,0 +1,33 @@
+import pytest
+import torch
+from torch.autograd import gradcheck, gradgradcheck
+
+
+class TestFusedBiasLeakyReLU(object):
+
+    @classmethod
+    def setup_class(cls):
+        if not torch.cuda.is_available():
+            return
+        cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda()
+        cls.bias = torch.zeros(2, requires_grad=True).cuda()
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+    def test_gradient(self):
+
+        from mmcv.ops import FusedBiasLeakyReLU
+        gradcheck(
+            FusedBiasLeakyReLU(2).cuda(),
+            self.input_tensor,
+            eps=1e-4,
+            atol=1e-3)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+    def test_gradgradient(self):
+
+        from mmcv.ops import FusedBiasLeakyReLU
+        gradgradcheck(
+            FusedBiasLeakyReLU(2).cuda(),
+            self.input_tensor,
+            eps=1e-4,
+            atol=1e-3)
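The checks above run in float32 with loosened eps/atol. Since `AT_DISPATCH_FLOATING_TYPES_AND_HALF` also covers float64, a stricter double-precision variant is possible (a sketch, not part of this PR):

```python
import torch
from torch.autograd import gradcheck

from mmcv.ops import FusedBiasLeakyReLU

x = torch.randn(
    2, 2, 2, 2, dtype=torch.float64, device='cuda', requires_grad=True)
gradcheck(FusedBiasLeakyReLU(2).double().cuda(), x)
```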
diff --git a/tests/test_ops/test_upfirdn2d.py b/tests/test_ops/test_upfirdn2d.py
new file mode 100644
index 0000000000..59e8e69e42
--- /dev/null
+++ b/tests/test_ops/test_upfirdn2d.py
@@ -0,0 +1,41 @@
+import pytest
+import torch
+from torch.autograd import gradcheck, gradgradcheck
+
+
+class TestUpFirDn2d(object):
+    """Unit test for UpFirDn2d.
+
+    Here we just test the basic case of the upsampling version. More general
+    tests will be included in other unit tests for the UpFirDnUpsample and
+    UpFirDnDownSample modules.
+    """
+
+    @classmethod
+    def setup_class(cls):
+        kernel_1d = torch.tensor([1., 3., 3., 1.])
+        cls.kernel = kernel_1d[:, None] * kernel_1d[None, :]
+        cls.kernel = cls.kernel / cls.kernel.sum()
+        cls.factor = 2
+        pad = cls.kernel.shape[0] - cls.factor
+        cls.pad = ((pad + 1) // 2 + cls.factor - 1, pad // 2)
+
+        cls.input_tensor = torch.randn((2, 3, 4, 4), requires_grad=True)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+    def test_upfirdn2d(self):
+        from mmcv.ops import upfirdn2d
+
+        gradcheck(
+            upfirdn2d,
+            (self.input_tensor.cuda(), self.kernel.type_as(
+                self.input_tensor).cuda(), self.factor, 1, self.pad),
+            eps=1e-4,
+            atol=1e-3)
+
+        gradgradcheck(
+            upfirdn2d,
+            (self.input_tensor.cuda(), self.kernel.type_as(
+                self.input_tensor).cuda(), self.factor, 1, self.pad),
+            eps=1e-4,
+            atol=1e-3)
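Finally, a quick way to cross-check the CUDA path against the pure-PyTorch reference this PR ships (tolerances illustrative; requires a GPU):

```python
import torch
from mmcv.ops import upfirdn2d
from mmcv.ops.upfirdn2d import upfirdn2d_native

kernel_1d = torch.tensor([1., 3., 3., 1.])
kernel = kernel_1d[:, None] * kernel_1d[None, :]
kernel = kernel / kernel.sum()

x = torch.randn(2, 3, 8, 8)
# Native path with explicit (up_x, up_y, down_x, down_y, pads) arguments.
ref = upfirdn2d_native(x, kernel, 2, 2, 1, 1, 2, 1, 2, 1)
# CUDA path through the public wrapper with the same configuration.
out = upfirdn2d(x.cuda(), kernel.cuda(), up=2, down=1, pad=(2, 1)).cpu()
torch.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4)
```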