Modified code for using ROCm backend within the PyTorch framework #1918

Merged (5 commits) on Sep 13, 2022
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
@@ -8,7 +8,7 @@
#include "pytorch_cuda_helper.hpp"
#endif

#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
@@ -29,7 +29,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
int index = w + (h + (c + n * channel_num) * height) * width;
return index;
}
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -44,7 +44,7 @@ __device__ inline scalar_t max(scalar_t a, scalar_t b) {
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
val += __shfl_down(val, offset);
#else
val += __shfl_down_sync(FULL_MASK, val, offset);
@@ -55,8 +55,8 @@ __device__ __forceinline__ scalar_t warpReduceSum(phalf val) {
template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef HIP_DIFF
__PHALF(val) += __shfl_down(FULL_MASK, val, offset);
#ifdef MMCV_WITH_HIP
__PHALF(val) += __shfl_down(val, offset);
#else
__PHALF(val) +=
__shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
@@ -316,7 +316,7 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
__syncthreads();
#else
__syncwarp();
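The hunks above replace the old `HIP_DIFF` guard with `MMCV_WITH_HIP` and switch the warp primitives accordingly: ROCm wavefronts are 64 lanes wide and HIP provides only the unmasked `__shfl_down`, while CUDA uses 32-lane warps and `__shfl_down_sync`. The `phalf` overload also drops the stray `FULL_MASK` argument, since HIP's `__shfl_down` takes just the value and the lane delta. A minimal, self-contained sketch of this guard pattern (not part of the PR; the `warp_sum_demo` kernel and the local `FULL_MASK` define are illustrative only):

```cuda
// Sketch only: mirrors the guard used in carafe_cuda_kernel.cuh.
#ifdef MMCV_WITH_HIP
#define WARP_SIZE 64  // AMD wavefront
#else
#define WARP_SIZE 32  // NVIDIA warp
#endif
#define FULL_MASK 0xffffffff  // local define for the sketch

// Sum `val` across one warp/wavefront; lane 0 ends up with the full sum.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
    val += __shfl_down(val, offset);  // HIP has no *_sync shuffle variants
#else
    val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
  return val;
}

// Illustrative kernel: every warp reduces WARP_SIZE consecutive inputs.
__global__ void warp_sum_demo(const float *in, float *out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float v = idx < n ? in[idx] : 0.f;
  v = warpReduceSum(v);
  if (idx < n && threadIdx.x % WARP_SIZE == 0) out[idx / WARP_SIZE] = v;
}
```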
4 changes: 4 additions & 0 deletions mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
@@ -78,7 +78,11 @@ __global__ void correlation_forward_cuda_kernel(
}
// accumulate
for (int offset = 16; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
prod_sum += __shfl_down(float(prod_sum), offset);
#else
prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
#endif
if (thread == 0) {
output[n][ph][pw][h][w] = prod_sum;
}
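The same two-branch guard is repeated at every shuffle call site touched by this PR. As a hedged design note (not something the PR does), the backend difference could also be hidden behind one small device helper so each reduction loop stays free of `#ifdef`s; `shfl_down_any` and `warp_prod_sum` below are hypothetical names:

```cuda
// Hypothetical helper, not part of the PR: one name for both backends.
#define FULL_MASK 0xffffffff

template <typename scalar_t>
__device__ __forceinline__ scalar_t shfl_down_any(scalar_t val, int offset) {
#ifdef MMCV_WITH_HIP
  return __shfl_down(val, offset);
#else
  return __shfl_down_sync(FULL_MASK, val, offset);
#endif
}

// Call site mirroring the accumulation loop in correlation_forward_cuda_kernel.
__device__ __forceinline__ float warp_prod_sum(float prod_sum) {
  for (int offset = 16; offset > 0; offset /= 2)
    prod_sum += shfl_down_any(prod_sum, offset);
  return prod_sum;
}
```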
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
@@ -34,7 +34,7 @@ __device__ __forceinline__ static void reduceMax(double *address, double val) {
}

// get rid of meaningless warnings when compiling host code
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
atomicAdd(address, val);
}
@@ -86,7 +86,7 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
#endif
}
#endif // __CUDA_ARCH__
#endif // HIP_DIFF
#endif // MMCV_WITH_HIP

template <typename T>
__global__ void feats_reduce_kernel(
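Under ROCm the header keeps the plain `atomicAdd` overloads, while the CUDA branch retains the compare-and-swap emulation needed where a native atomic is missing (for example `atomicAdd(double*, double)` before sm_60). A simplified sketch of that dispatch, assuming only the `double` overload and omitting the other scalar types the real header covers:

```cuda
// Simplified sketch of the dispatch in scatter_points_cuda_kernel.cuh;
// only the double overload is shown, the real header covers more types.
#ifdef MMCV_WITH_HIP
// ROCm path: use the hardware atomic directly.
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
  atomicAdd(address, val);
}
#else
// CUDA path: emulate atomicAdd(double) with a CAS loop on pre-sm_60 devices.
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
  unsigned long long int *addr_as_ull =
      reinterpret_cast<unsigned long long int *>(address);
  unsigned long long int old = *addr_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
#else
  atomicAdd(address, val);
#endif
}
#endif  // MMCV_WITH_HIP
```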
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h
@@ -27,7 +27,7 @@

namespace tv {

#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIP__)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__
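hipcc does not define `__NVCC__`, so before this change the `TV_*` qualifiers lost their `__device__`/`__host__` attributes in ROCm builds and the spconv device code could not compile. A stripped-down illustration of the selection after the change; the `#else` fallback and the `divUp` helper are assumptions for the sketch, not copied from the header:

```cpp
// Stripped-down view of the macro selection; the #else branch is assumed.
#if defined(__NVCC__) || defined(__HIP__)
// nvcc or hipcc: qualify helpers for both host and device.
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__
#else
// Plain host compiler: degrade to ordinary inline functions.
#define TV_HOST_DEVICE_INLINE inline
#define TV_DEVICE_INLINE inline
#define TV_HOST_DEVICE
#endif

// Illustrative helper that now builds as device code under ROCm as well.
TV_HOST_DEVICE_INLINE int divUp(int a, int b) { return (a + b - 1) / b; }
```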
15 changes: 12 additions & 3 deletions mmcv/ops/csrc/parrots/info.cpp
@@ -4,15 +4,22 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifdef MMCV_WITH_HIP
#include <hip/hip_runtime_api.h>
int get_hiprt_version() {
int runtimeVersion;
hipRuntimeGetVersion(&runtimeVersion);
return runtimeVersion;
}
#else
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
#endif

std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("rocm not available");
std::ostringstream oss;
oss << get_hiprt_version();
return oss.str();
#endif
#else
return std::string("not available");
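With this hunk, ROCm builds report the HIP runtime version instead of the hard-coded "rocm not available" string. A hedged sketch of how the HIP branch resolves when compiled with hipcc against a ROCm install; `get_compiling_cuda_version_hip` and `main` are illustrative names, and the version is printed as the raw integer filled in by `hipRuntimeGetVersion`:

```cpp
// Hedged sketch, assuming a ROCm install and compilation with hipcc.
#include <hip/hip_runtime_api.h>

#include <iostream>
#include <sstream>
#include <string>

static int get_hiprt_version() {
  int runtimeVersion = 0;
  hipRuntimeGetVersion(&runtimeVersion);  // queries the loaded HIP runtime
  return runtimeVersion;
}

static std::string get_compiling_cuda_version_hip() {
  std::ostringstream oss;
  oss << get_hiprt_version();  // reported as the raw integer, unlike the
                               // dotted CUDA-style string on the CUDA branch
  return oss.str();
}

int main() {
  std::cout << get_compiling_cuda_version_hip() << std::endl;
  return 0;
}
```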
5 changes: 3 additions & 2 deletions mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
@@ -3,7 +3,7 @@
#include "pytorch_cuda_helper.hpp"

// Disable fp16 on ROCm device
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
#if __CUDA_ARCH__ >= 530
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
@@ -15,8 +15,9 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
reinterpret_cast<__half*>(ious), num_bbox1,
num_bbox2, mode, aligned, offset);
}

#endif // __CUDA_ARCH__ >= 530
#endif // HIP_DIFF
#endif // MMCV_WITH_HIP

void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
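The comment in this file notes that fp16 stays disabled on ROCm: the `at::Half` specialization is now kept behind `#ifndef MMCV_WITH_HIP`, so HIP builds compile only the generic kernel. A self-contained sketch of the same pattern using a hypothetical `elementwise_add` kernel (not the PR's code):

```cuda
// Self-contained illustration of keeping an fp16 specialization CUDA-only;
// ROCm builds fall back to the generic template.
#ifndef MMCV_WITH_HIP
#include <cuda_fp16.h>
#endif

// Generic elementwise add, compiled for every backend and dtype.
template <typename scalar_t>
__global__ void elementwise_add(const scalar_t *a, const scalar_t *b,
                                scalar_t *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = a[i] + b[i];
}

#ifndef MMCV_WITH_HIP
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
// fp16 specialization, only emitted for CUDA devices with native half math.
template <>
__global__ void elementwise_add<__half>(const __half *a, const __half *b,
                                        __half *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = __hadd(a[i], b[i]);
}
#endif  // __CUDA_ARCH__ >= 530
#endif  // MMCV_WITH_HIP
```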
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/fused_spconv_ops_cuda.cu
@@ -1,9 +1,9 @@
#include <cuda_runtime_api.h>
#include <torch/script.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/indice.h>
#include <utils/spconv/spconv/reordering.h>

#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"

torch::Tensor FusedIndiceConvBatchnormCUDAKernelLauncher(
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/sparse_indice.cu
@@ -13,6 +13,7 @@
// limitations under the License.

#include <ATen/ATen.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/indice.h>
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/tensorview/helper_launch.h>
@@ -23,7 +24,6 @@
#include <spconv/indice.cuh>
#include <type_traits>

#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"

namespace functor {
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/sparse_maxpool.cu
@@ -13,6 +13,7 @@
// limitations under the License.

#include <ATen/ATen.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/maxpool.h>
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/tensorview/helper_launch.h>
@@ -23,7 +24,6 @@
#include <type_traits>
#include <utils/spconv/tensorview/helper_kernel.cuh>

#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/sparse_pool_ops_cuda.cu
@@ -1,8 +1,8 @@
#include <cuda_runtime_api.h>
#include <torch/script.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/maxpool.h>

#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"

torch::Tensor IndiceMaxpoolForwardCUDAKernelLauncher(torch::Tensor features,
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/sparse_reordering.cu
@@ -13,6 +13,7 @@
// limitations under the License.

#include <ATen/ATen.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/spconv/reordering.h>
#include <utils/spconv/tensorview/helper_launch.h>
@@ -24,7 +25,6 @@
#include <type_traits>
#include <utils/spconv/tensorview/helper_kernel.cuh>

#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"

namespace functor {
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu
@@ -1,9 +1,9 @@
#include <cuda_runtime_api.h>
#include <torch/script.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/indice.h>
#include <utils/spconv/spconv/reordering.h>

#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"

template <unsigned NDim>
15 changes: 12 additions & 3 deletions mmcv/ops/csrc/pytorch/info.cpp
@@ -4,15 +4,22 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifdef MMCV_WITH_HIP
#include <hip/hip_runtime_api.h>
int get_hiprt_version() {
int runtimeVersion;
hipRuntimeGetVersion(&runtimeVersion);
return runtimeVersion;
}
#else
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
#endif

std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("rocm not available");
std::ostringstream oss;
oss << get_hiprt_version();
return oss.str();
#endif
#else
return std::string("not available");
30 changes: 21 additions & 9 deletions mmcv/utils/env.py
@@ -55,15 +55,27 @@ def collect_env():
env_info['CUDA_HOME'] = CUDA_HOME

if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
try:
nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
nvcc = nvcc.decode('utf-8').strip()
release = nvcc.rfind('Cuda compilation tools')
build = nvcc.rfind('Build ')
nvcc = nvcc[release:build].strip()
except subprocess.SubprocessError:
nvcc = 'Not Available'
if CUDA_HOME == '/opt/rocm':
try:
nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc')
nvcc = subprocess.check_output(
f'"{nvcc}" --version', shell=True)
nvcc = nvcc.decode('utf-8').strip()
release = nvcc.rfind('HIP version:')
build = nvcc.rfind('')
nvcc = nvcc[release:build].strip()
except subprocess.SubprocessError:
nvcc = 'Not Available'
else:
try:
nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
nvcc = nvcc.decode('utf-8').strip()
release = nvcc.rfind('Cuda compilation tools')
build = nvcc.rfind('Build ')
nvcc = nvcc[release:build].strip()
except subprocess.SubprocessError:
nvcc = 'Not Available'
env_info['NVCC'] = nvcc

try:
3 changes: 2 additions & 1 deletion setup.py
@@ -280,7 +280,7 @@ def get_extensions():
if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
'FORCE_CUDA', '0') == '1':
if is_rocm_pytorch:
define_macros += [('HIP_DIFF', None)]
define_macros += [('MMCV_WITH_HIP', None)]
define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
@@ -289,6 +289,7 @@ def get_extensions():
glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp')
extension = CUDAExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
elif (hasattr(torch, 'is_mlu_available') and
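Note that for ROCm builds setup.py now defines `MMCV_WITH_HIP` in addition to `MMCV_WITH_CUDA`, which is why the kernel sources above nest `#ifdef MMCV_WITH_HIP` inside `#ifdef MMCV_WITH_CUDA`. A tiny, hypothetical translation-unit check of that layering (not part of the PR):

```cpp
// Hypothetical check that the macros defined in setup.py reached the C++
// translation units: ROCm builds see both MMCV_WITH_CUDA and MMCV_WITH_HIP.
#include <cstdio>

int main() {
#if defined(MMCV_WITH_CUDA) && defined(MMCV_WITH_HIP)
  std::printf("built for ROCm (HIP) through the CUDA code path\n");
#elif defined(MMCV_WITH_CUDA)
  std::printf("built for CUDA\n");
#else
  std::printf("CPU-only build\n");
#endif
  return 0;
}
```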