Skip to content

Commit

Permalink
Make PyTorch3D C++17 incompatible again :(
Browse files Browse the repository at this point in the history
Summary: D38919607 (c4545a7) and D38858887 (06cbba2) were premature, turns out CUDA 10.2 doesn't support C++17.

Reviewed By: bottler

Differential Revision: D39156205

fbshipit-source-id: 5e2e84cc4a57d1113a915166631651d438540d56
  • Loading branch information
Krzysztof Chalupka authored and facebook-github-bot committed Aug 31, 2022
1 parent 1530a66 commit 9577198
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
8 changes: 4 additions & 4 deletions pytorch3d/csrc/iou_box3d/iou_box3d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ __global__ void IoUBox3DKernel(
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;

std::array<FaceVerts, NUM_TRIS> box1_tris{};
std::array<FaceVerts, NUM_TRIS> box2_tris{};
std::array<FaceVerts, NUM_PLANES> box1_planes{};
std::array<FaceVerts, NUM_PLANES> box2_planes{};
FaceVerts box1_tris[NUM_TRIS];
FaceVerts box2_tris[NUM_TRIS];
FaceVerts box1_planes[NUM_PLANES];
FaceVerts box2_planes[NUM_PLANES];

for (size_t i = tid; i < N * M; i += stride) {
const size_t n = i / M; // box1 index
Expand Down
10 changes: 7 additions & 3 deletions pytorch3d/csrc/iou_box3d/iou_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ __device__ inline float3 FaceNormal(
auto normal = float3();
auto maxDist = -1;
for (auto v1 = vertices.begin(); v1 != vertices.end() - 1; ++v1) {
for (auto v2 = std::next(v1); v2 != vertices.end(); ++v2) {
for (auto v2 = v1 + 1; v2 != vertices.end(); ++v2) {
const auto v1ToCenter = *v1 - faceCenter;
const auto v2ToCenter = *v2 - faceCenter;
const auto dist = norm(cross(v1ToCenter, v2ToCenter));
Expand Down Expand Up @@ -472,8 +472,10 @@ __device__ inline bool IsCoplanarTriTri(
const bool check1 = abs(dot(tri1_n, tri2_n)) > 1 - dEpsilon;

// Compute most distant points
const auto [v1m, v2m] =
const auto v1mAndv2m =
ArgMaxVerts({tri1.v0, tri1.v1, tri1.v2}, {tri2.v0, tri2.v1, tri2.v2});
const auto v1m = std::get<0>(v1mAndv2m);
const auto v2m = std::get<1>(v1mAndv2m);

float3 n12m = v1m - v2m;
n12m = n12m / fmaxf(norm(n12m), kEpsilon);
Expand Down Expand Up @@ -506,8 +508,10 @@ __device__ inline bool IsCoplanarTriPlane(
const bool check1 = abs(dot(nt, normal)) > 1 - dEpsilon;

// Compute most distant points
const auto [v1m, v2m] = ArgMaxVerts(
const auto v1mAndv2m = ArgMaxVerts(
{tri.v0, tri.v1, tri.v2}, {plane.v0, plane.v1, plane.v2, plane.v3});
const auto v1m = std::get<0>(v1mAndv2m);
const auto v2m = std::get<1>(v1mAndv2m);

float3 n12m = v1m - v2m;
n12m = n12m / fmaxf(norm(n12m), kEpsilon);
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_extensions():
source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
extension = CppExtension

extra_compile_args = {"cxx": []}
extra_compile_args = {"cxx": ["-std=c++14"]}
define_macros = []
include_dirs = [extensions_dir]

Expand All @@ -73,6 +73,8 @@ def get_extensions():
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
if os.name != "nt":
nvcc_args.append("-std=c++14")
if cub_home is None:
prefix = os.environ.get("CONDA_PREFIX", None)
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
Expand Down

0 comments on commit 9577198

Please sign in to comment.