Modified code for using ROCm backend within the PyTorch framework #1918

Merged
merged 5 commits into from
Sep 13, 2022

Contributor

@zstreet87 zstreet87 commented Apr 26, 2022

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Related issue: #1898

This PR enables HIP as a backend for PyTorch only. It uses hipify to generate the HIP source and include folders via JIT compilation within ROCm PyTorch. I tested these changes with multiple Docker containers, e.g., 'rocm/pytorch:rocm5.0_ubuntu18.04_py3.7_pytorch_1.10.0'.

Modification

  1. Renamed 'HIP_DIFF' to 'MMCV_WITH_HIP' for style consistency.
  2. Added './mmcv/ops/csrc/pytorch' to the include directories, because hipify in PyTorch needs to hipify 'spconv_utils.h'.
  3. rocBLAS has an ambiguity for 'is_complex' that ultimately comes from the order of the include files in some sources. The workaround in this PR is to add '#include spconv_utils.h' before the other file to break the ambiguity. Unfortunately, this may go against the style guideline. The issue has been reported and will be addressed in future versions of ROCm.
  4. Added 'defined(__HIP__)' to tensorview.h to allow host function calls in __global__ functions in HIP.
  5. Added __shfl_down without a mask in ./mmcv/ops/common/cuda/correlation_cuda.cuh, since HIP does not have warp-level granularity yet. The test in ./tests/test_ops/test_correlation.py passes.
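
To illustrate item 5, here is a minimal host-side sketch of the shuffle-down reduction pattern that a warp-level sum relies on. It is plain C++ that simulates one warp on the CPU, not the code from this PR; `kWarpSize`, `shfl_down_all`, and `warp_reduce_sum` are hypothetical names, and the width of 32 lanes is an assumption (AMD wavefronts can be 64 lanes wide). On the device, the per-lane step would be `__shfl_down(val, offset)` under HIP and `__shfl_down_sync(mask, val, offset)` under CUDA.

```cpp
#include <array>

// Assumed warp width for this illustration only.
constexpr int kWarpSize = 32;

// Host-side stand-in for a warp shuffle: lane i reads the value held by
// lane i + offset. Lanes whose source lane would be out of range keep
// their own value.
inline std::array<float, kWarpSize> shfl_down_all(
    const std::array<float, kWarpSize>& v, int offset) {
  std::array<float, kWarpSize> out = v;
  for (int lane = 0; lane + offset < kWarpSize; ++lane) {
    out[lane] = v[lane + offset];
  }
  return out;
}

// Tree reduction: after log2(kWarpSize) steps, lane 0 holds the sum of
// all lanes. Each step mirrors `val += shfl_down(val, offset)` per lane.
inline float warp_reduce_sum(std::array<float, kWarpSize> v) {
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const std::array<float, kWarpSize> shifted = shfl_down_all(v, offset);
    for (int lane = 0; lane < kWarpSize; ++lane) {
      v[lane] += shifted[lane];
    }
  }
  return v[0];
}
```

Only lane 0 is guaranteed to hold the full sum at the end; the other lanes hold partial sums, which is why kernels using this pattern typically let lane 0 write the result.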

BC-breaking (Optional)

N/A

Use cases (Optional)

Using MMCV with the pytorch framework on AMD GPUs.

Checklist

Before PR:

  • I have read and followed the workflow indicated in the CONTRIBUTING.md to create this PR.
  • Pre-commit or linting tools indicated in CONTRIBUTING.md are used to fix the potential lint issues.
  • Bug fixes are covered by unit tests, the case that causes the bug should be added in the unit tests.
  • New functionalities are covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  • The documentation has been modified accordingly, including docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with some of those projects, like MMDet or MMCls.
  • CLA has been signed and all committers have signed the CLA in this PR.

@@ -4,15 +4,22 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifdef MMCV_WITH_HIP
#include <cuda_runtime_api.h>
Member

Should we include hip runtime here?

Contributor Author

Yeah, I thought about that when I added this code. It works as is because hipify rewrites this header to the HIP runtime header. Still, since it sits just below "#ifdef MMCV_WITH_HIP", having "#include <hip/hip_runtime_api.h>" explicitly is less confusing. I'll make that change now!
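
A minimal sketch of what the explicit include could look like, assuming the macro nesting shown in the diff above (`MMCV_WITH_HIP` checked inside `MMCV_WITH_CUDA`). The helper `compiled_backend` is hypothetical and not part of MMCV; it is only there so the selection can be checked. With neither macro defined, the file compiles as plain C++.

```cpp
#include <string>

// MMCV_WITH_CUDA and MMCV_WITH_HIP are the build macros named in this PR.
// The HIP header is named explicitly instead of relying on hipify to
// rewrite the CUDA one.
#ifdef MMCV_WITH_CUDA
#ifdef MMCV_WITH_HIP
#include <hip/hip_runtime_api.h>
#else
#include <cuda_runtime_api.h>
#endif
#endif

// Hypothetical helper (illustration only): reports which backend the
// preprocessor selected at compile time.
inline std::string compiled_backend() {
#if defined(MMCV_WITH_CUDA) && defined(MMCV_WITH_HIP)
  return "hip";
#elif defined(MMCV_WITH_CUDA)
  return "cuda";
#else
  return "cpu";
#endif
}
```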

@grimoire
Member

Do all the unit tests pass for you?

pytest tests/test_ops

tests/test_ops/test_convex_iou.py, tests/test_ops/test_roiaware_pool3d.py, and tests/test_ops/test_voxelization.py fail with "Fatal Python error: Aborted" in my run.
I am not sure whether this has anything to do with your PR, since I haven't tested it for a long time...

@zstreet87
Contributor Author

It looks like those tests ran fine on my end. I'm only seeing quite a few warnings like the following:

DeprecationWarning: "out_size" is deprecated in `RoIAlignRotated.__init__`, please use "output_size" instead
    'instead', DeprecationWarning)

Here is the summary line from running pytest against tests/test_ops:

131 passed, 35 skipped, 94 warnings in 55.46s

Note the particular tests you mentioned seem to be fine...

@grimoire
Member

Sure, I see. Maybe there is something wrong with my env. I will let someone else give it a try.

@teamwong111
Contributor

My env is gfx906, ROCm 4.0.0, PyTorch 1.8.0. pytest tests/test_ops/ gives me Aborted (core dumped).

@zstreet87
Contributor Author

zstreet87 commented May 9, 2022

> My env is gfx906, ROCm 4.0.0, PyTorch 1.8.0. pytest tests/test_ops/ gives me Aborted (core dumped).

Can you inspect the core dump file? Also, can you try a newer version of ROCm in a container? Hopefully it's not a driver issue...

I see #1704 used this version of ROCm, and that work was merged into main. However, that code doesn't compile on the newer, now-supported versions of ROCm, which is what this PR attempts to address. This may impose a ROCm version requirement, which unfortunately is not ideal.

@zstreet87
Contributor Author

Hi there,

Why are the unit tests on hold?

@zstreet87 zstreet87 requested a review from grimoire July 26, 2022 22:56
@grimoire
Member

Sorry, my old ROCm environment is not available anymore. I will find a new one and continue this review ASAP.

@zstreet87
Contributor Author

> Sorry, my old ROCm environment is not available anymore. I will find a new one and continue this review ASAP.

Looks like we have a green checkmark :)

Member

@grimoire grimoire left a comment

I still hit the same error. But since ROCm support is currently broken anyway, we can merge this and fix the random core dump in the future.

@zstreet87
Contributor Author

> I still hit the same error. But since ROCm support is currently broken anyway, we can merge this and fix the random core dump in the future.

Great! Look forward to the merge 👍

@OpenMMLab-Assistant-004

Hi @zstreet87! First of all, we want to express our gratitude for your significant PR in the mmcv project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project during your personal time. We believe that many developers will benefit from your PR.

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/raweFPmdzG

If you have WeChat, you are welcome to join our community there. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends :)
Thank you again for your contribution ❤
