Skip to content

Commit

Permalink
add min_triangle_area argument to IsInsideTriangle
Browse files Browse the repository at this point in the history
Summary:
1. changed IsInsideTriangle in geometry_utils to take in min_triangle_area parameter instead of hardcoded value
2. updated point_mesh_cpu.cpp and point_mesh_cuda.[h/cu] to adapt to changes in geometry_utils function signatures
3. updated point_mesh_distance.py and test_point_mesh_distance.py to modify _C. calls

Reviewed By: bottler

Differential Revision: D34459764

fbshipit-source-id: 0549e78713c6d68f03d85fb597a13dd88e09b686
  • Loading branch information
winnie1994 authored and facebook-github-bot committed Feb 25, 2022
1 parent 4d043fc commit 471b126
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 134 deletions.
114 changes: 72 additions & 42 deletions pytorch3d/csrc/point_mesh/point_mesh_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,33 @@ void IncrementPoint(at::TensorAccessor<T, 1>&& t, const vec3<T>& point) {
template <typename T>
T HullDistance(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 2>& b) {
const std::array<vec3<T>, 2>& b,
const double /*min_triangle_area*/) {
using std::get;
return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b));
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 3>& b) {
const std::array<vec3<T>, 3>& b,
const double min_triangle_area) {
using std::get;
return PointTriangle3DistanceForward(
get<0>(a), get<0>(b), get<1>(b), get<2>(b));
get<0>(a), get<0>(b), get<1>(b), get<2>(b), min_triangle_area);
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 2>& a,
const std::array<vec3<T>, 1>& b) {
return HullDistance(b, a);
const std::array<vec3<T>, 1>& b,
const double /*min_triangle_area*/) {
return HullDistance(b, a, 1);
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 3>& a,
const std::array<vec3<T>, 1>& b) {
return HullDistance(b, a);
const std::array<vec3<T>, 1>& b,
const double min_triangle_area) {
return HullDistance(b, a, min_triangle_area);
}

template <typename T>
Expand All @@ -88,7 +92,8 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 2>& b,
T grad_dist,
at::TensorAccessor<T, 1>&& grad_a,
at::TensorAccessor<T, 2>&& grad_b) {
at::TensorAccessor<T, 2>&& grad_b,
const double /*min_triangle_area*/) {
using std::get;
auto res =
PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist);
Expand All @@ -102,10 +107,11 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 3>& b,
T grad_dist,
at::TensorAccessor<T, 1>&& grad_a,
at::TensorAccessor<T, 2>&& grad_b) {
at::TensorAccessor<T, 2>&& grad_b,
const double min_triangle_area) {
using std::get;
auto res = PointTriangle3DistanceBackward(
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist);
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist, min_triangle_area);
IncrementPoint(std::move(grad_a), get<0>(res));
IncrementPoint(grad_b[0], get<1>(res));
IncrementPoint(grad_b[1], get<2>(res));
Expand All @@ -117,19 +123,21 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 1>& b,
T grad_dist,
at::TensorAccessor<T, 2>&& grad_a,
at::TensorAccessor<T, 1>&& grad_b) {
at::TensorAccessor<T, 1>&& grad_b,
const double min_triangle_area) {
return HullHullDistanceBackward(
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
b, a, grad_dist, std::move(grad_b), std::move(grad_a), min_triangle_area);
}
template <typename T>
void HullHullDistanceBackward(
const std::array<vec3<T>, 2>& a,
const std::array<vec3<T>, 1>& b,
T grad_dist,
at::TensorAccessor<T, 2>&& grad_a,
at::TensorAccessor<T, 1>&& grad_b) {
at::TensorAccessor<T, 1>&& grad_b,
const double /*min_triangle_area*/) {
return HullHullDistanceBackward(
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
b, a, grad_dist, std::move(grad_b), std::move(grad_a), 1);
}

template <int H>
Expand All @@ -150,7 +158,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
const at::Tensor& as,
const at::Tensor& as_first_idx,
const at::Tensor& bs,
const at::Tensor& bs_first_idx) {
const at::Tensor& bs_first_idx,
const double min_triangle_area) {
const int64_t A_N = as.size(0);
const int64_t B_N = bs.size(0);
const int64_t BATCHES = as_first_idx.size(0);
Expand Down Expand Up @@ -190,7 +199,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
size_t min_idx = 0;
auto a = ExtractHull<H1>(as_a[a_n]);
for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) {
float dist = HullDistance(a, ExtractHull<H2>(bs_a[b_n]));
float dist =
HullDistance(a, ExtractHull<H2>(bs_a[b_n]), min_triangle_area);
if (dist <= min_dist) {
min_dist = dist;
min_idx = b_n;
Expand All @@ -208,7 +218,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
const at::Tensor& as,
const at::Tensor& bs,
const at::Tensor& idx_bs,
const at::Tensor& grad_dists) {
const at::Tensor& grad_dists,
const double min_triangle_area) {
const int64_t A_N = as.size(0);

TORCH_CHECK(idx_bs.size(0) == A_N);
Expand All @@ -230,15 +241,21 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
auto a = ExtractHull<H1>(as_a[a_n]);
auto b = ExtractHull<H2>(bs_a[idx_bs_a[a_n]]);
HullHullDistanceBackward(
a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]);
a,
b,
grad_dists_a[a_n],
grad_as_a[a_n],
grad_bs_a[idx_bs_a[a_n]],
min_triangle_area);
}
return std::make_tuple(grad_as, grad_bs);
}

template <int H>
torch::Tensor PointHullArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& bs) {
const torch::Tensor& bs,
const double min_triangle_area) {
const int64_t P = points.size(0);
const int64_t B_N = bs.size(0);

Expand All @@ -254,7 +271,7 @@ torch::Tensor PointHullArrayDistanceForwardCpu(
auto dest = dists_a[p];
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
auto b = ExtractHull<H>(bs_a[b_n]);
dest[b_n] = HullDistance(point, b);
dest[b_n] = HullDistance(point, b, min_triangle_area);
}
}
return dists;
Expand All @@ -264,7 +281,8 @@ template <int H>
std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& bs,
const at::Tensor& grad_dists) {
const at::Tensor& grad_dists,
const double min_triangle_area) {
const int64_t P = points.size(0);
const int64_t B_N = bs.size(0);

Expand All @@ -287,7 +305,12 @@ std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
auto b = ExtractHull<H>(bs_a[b_n]);
HullHullDistanceBackward(
point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]);
point,
b,
grad_dist[b_n],
std::move(grad_point),
grad_bs_a[b_n],
min_triangle_area);
}
}
return std::make_tuple(grad_points, grad_bs);
Expand All @@ -299,63 +322,70 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx) {
const torch::Tensor& tris_first_idx,
const double min_triangle_area) {
return HullHullDistanceForwardCpu<1, 3>(
points, points_first_idx, tris, tris_first_idx);
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}

std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
const torch::Tensor& grad_dists,
const double min_triangle_area) {
return HullHullDistanceBackwardCpu<1, 3>(
points, tris, idx_points, grad_dists);
points, tris, idx_points, grad_dists, min_triangle_area);
}

std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx) {
const torch::Tensor& tris_first_idx,
const double min_triangle_area) {
return HullHullDistanceForwardCpu<3, 1>(
tris, tris_first_idx, points, points_first_idx);
tris, tris_first_idx, points, points_first_idx, min_triangle_area);
}

std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists) {
auto res =
HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area) {
auto res = HullHullDistanceBackwardCpu<3, 1>(
tris, points, idx_tris, grad_dists, min_triangle_area);
return std::make_tuple(std::get<1>(res), std::get<0>(res));
}

torch::Tensor PointEdgeArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms) {
return PointHullArrayDistanceForwardCpu<2>(points, segms);
return PointHullArrayDistanceForwardCpu<2>(points, segms, 1);
}

std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& grad_dists) {
return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists);
const at::Tensor& grad_dists,
const double min_triangle_area) {
return PointHullArrayDistanceBackwardCpu<3>(
points, tris, grad_dists, min_triangle_area);
}

torch::Tensor PointFaceArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris) {
return PointHullArrayDistanceForwardCpu<3>(points, tris);
const torch::Tensor& tris,
const double min_triangle_area) {
return PointHullArrayDistanceForwardCpu<3>(points, tris, min_triangle_area);
}

std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& grad_dists) {
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists);
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists, 1);
}

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
Expand All @@ -365,7 +395,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
const torch::Tensor& segms_first_idx,
const int64_t /*max_points*/) {
return HullHullDistanceForwardCpu<1, 2>(
points, points_first_idx, segms, segms_first_idx);
points, points_first_idx, segms, segms_first_idx, 1);
}

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
Expand All @@ -374,7 +404,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
return HullHullDistanceBackwardCpu<1, 2>(
points, segms, idx_points, grad_dists);
points, segms, idx_points, grad_dists, 1);
}

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
Expand All @@ -384,15 +414,15 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
const torch::Tensor& segms_first_idx,
const int64_t /*max_segms*/) {
return HullHullDistanceForwardCpu<2, 1>(
segms, segms_first_idx, points, points_first_idx);
segms, segms_first_idx, points, points_first_idx, 1);
}

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists) {
auto res =
HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists);
auto res = HullHullDistanceBackwardCpu<2, 1>(
segms, points, idx_segms, grad_dists, 1);
return std::make_tuple(std::get<1>(res), std::get<0>(res));
}
Loading

0 comments on commit 471b126

Please sign in to comment.