Commit 7944d24: gather_scatter on CPU
Summary: CPU implementation of the graph convolution op.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D21384361

fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b
bottler authored and facebook-github-bot committed Aug 5, 2020
1 parent 4872a2c commit 7944d24
Showing 4 changed files with 68 additions and 19 deletions.
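For context, gather_scatter is the neighbor-aggregation step of the graph convolution: for each edge (v0, v1) it adds the features of vertex v1 into the output row of vertex v0, and also the reverse when the graph is undirected. A minimal PyTorch sketch of the same computation (illustrative only; the test suite's pure-Python reference is called gather_scatter_python):

import torch

def gather_scatter_reference(input, edges, directed=False):
    # input: (V, D) float32 vertex features; edges: (E, 2) int64 vertex indices.
    output = torch.zeros_like(input)
    v0, v1 = edges[:, 0], edges[:, 1]
    output.index_add_(0, v0, input[v1])  # output[v0] += input[v1] for every edge
    if not directed:
        output.index_add_(0, v1, input[v0])  # symmetric accumulation
    return output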
4 changes: 2 additions & 2 deletions pytorch3d/csrc/gather_scatter/gather_scatter.cu
@@ -44,8 +44,8 @@ __global__ void GatherScatterCudaKernel(
 }
 
 at::Tensor GatherScatterCuda(
-    const at::Tensor input,
-    const at::Tensor edges,
+    const at::Tensor& input,
+    const at::Tensor& edges,
     bool directed,
     bool backward) {
   // Check inputs are on the same device
17 changes: 11 additions & 6 deletions pytorch3d/csrc/gather_scatter/gather_scatter.h
@@ -20,17 +20,22 @@
 // Returns:
 //   output: float32 Tensor of same shape as input.
 
 // Cuda implementation.
 at::Tensor GatherScatterCuda(
-    const at::Tensor input,
-    const at::Tensor edges,
+    const at::Tensor& input,
+    const at::Tensor& edges,
     bool directed,
     bool backward);
 
+at::Tensor GatherScatterCpu(
+    const at::Tensor& input,
+    const at::Tensor& edges,
+    bool directed,
+    bool backward);
+
 // Exposed implementation.
 at::Tensor GatherScatter(
-    const at::Tensor input,
-    const at::Tensor edges,
+    const at::Tensor& input,
+    const at::Tensor& edges,
     bool directed,
     bool backward) {
   if (input.is_cuda() && edges.is_cuda()) {
@@ -42,5 +47,5 @@ at::Tensor GatherScatter(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
-  AT_ERROR("Not implemented on the CPU");
+  return GatherScatterCpu(input, edges, directed, backward);
 }
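With the dispatcher above, GatherScatter routes CUDA inputs to GatherScatterCuda (when compiled with GPU support) and all other inputs to the new GatherScatterCpu, instead of raising "Not implemented on the CPU". A usage sketch from Python, following the calling convention in the tests (arguments are input, edges, directed, backward; assumes a built pytorch3d extension exposing _C as the tests do):

import torch
from pytorch3d import _C

verts = torch.rand(10, 16)             # (V, D) float32 vertex features
edges = torch.randint(0, 10, (30, 2))  # (E, 2) int64 edge list
out = _C.gather_scatter(verts, edges, False, False)  # now works on CPU tensors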
35 changes: 35 additions & 0 deletions pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp
@@ -0,0 +1,35 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+#include <ATen/ATen.h>
+
+at::Tensor GatherScatterCpu(
+    const at::Tensor& input,
+    const at::Tensor& edges,
+    bool directed,
+    bool backward) {
+  const auto num_vertices = input.size(0);
+  const auto input_feature_dim = input.size(1);
+  const auto num_edges = edges.size(0);
+
+  auto output = at::zeros({num_vertices, input_feature_dim}, input.options());
+
+  auto input_a = input.accessor<float, 2>();
+  auto edges_a = edges.accessor<int64_t, 2>();
+  auto output_a = output.accessor<float, 2>();
+  const int v0_idx = backward ? 1 : 0;
+  const int v1_idx = backward ? 0 : 1;
+
+  for (int e = 0; e < num_edges; ++e) {
+    // Get indices of vertices which form the edge.
+    const int64_t v0 = edges_a[e][v0_idx];
+    const int64_t v1 = edges_a[e][v1_idx];
+
+    for (int d = 0; d < input_feature_dim; ++d) {
+      output_a[v0][d] += input_a[v1][d];
+      if (!directed) {
+        output_a[v1][d] += input_a[v0][d];
+      }
+    }
+  }
+  return output;
+}
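Note the backward flag: swapping v0_idx and v1_idx reverses every edge, which lets the same loop accumulate gradients for the directed forward pass. A small hand-checked example of what the forward CPU path computes (assumes the extension is built):

import torch
from pytorch3d import _C

input = torch.tensor([[1.0], [2.0], [3.0]])  # three vertices, 1-D features
edges = torch.tensor([[0, 1], [1, 2]])       # undirected edges (0,1) and (1,2)
out = _C.gather_scatter(input, edges, False, False)
# Edge (0,1): output[0] += input[1] and output[1] += input[0];
# edge (1,2): output[1] += input[2] and output[2] += input[1].
# out == [[2.0], [4.0], [2.0]]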
31 changes: 20 additions & 11 deletions tests/test_graph_conv.py
@@ -101,17 +101,24 @@ def test_backward(self):
         mesh = ico_sphere()
         verts = mesh.verts_packed()
         edges = mesh.edges_packed()
+        verts_cpu = verts.clone()
+        edges_cpu = edges.clone()
         verts_cuda = verts.clone().to(device)
         edges_cuda = edges.clone().to(device)
         verts.requires_grad = True
+        verts_cpu.requires_grad = True
         verts_cuda.requires_grad = True
 
         neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
+        neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
         neighbor_sums = gather_scatter_python(verts, edges, False)
-        neighbor_sums_cuda.sum().backward()
-        neighbor_sums.sum().backward()
+        randoms = torch.rand_like(neighbor_sums)
+        (neighbor_sums_cuda * randoms.cuda()).sum().backward()
+        (neighbor_sums_cpu * randoms).sum().backward()
+        (neighbor_sums * randoms).sum().backward()
 
-        self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
+        self.assertClose(verts.grad, verts_cuda.grad.cpu())
+        self.assertClose(verts.grad, verts_cpu.grad)
 
     def test_repr(self):
         conv = GraphConv(32, 64, directed=True)
@@ -141,22 +148,24 @@ def test_gather_scatter(self):
         w0 = nn.Linear(3, 1)
         input = w0(verts)
 
-        # output
-        output_cpu = gather_scatter_python(input, edges, False)
+        # undirected
+        output_python = gather_scatter_python(input, edges, False)
         output_cuda = _C.gather_scatter(
             input.to(device=device), edges.to(device=device), False, False
         )
-        self.assertClose(output_cuda.cpu(), output_cpu)
-        with self.assertRaises(Exception) as err:
-            _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
-        self.assertTrue("Not implemented on the CPU" in str(err.exception))
+        self.assertClose(output_cuda.cpu(), output_python)
+
+        output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
+        self.assertClose(output_cpu, output_python)
 
         # directed
-        output_cpu = gather_scatter_python(input, edges, True)
+        output_python = gather_scatter_python(input, edges, True)
         output_cuda = _C.gather_scatter(
             input.to(device=device), edges.to(device=device), True, False
         )
-        self.assertClose(output_cuda.cpu(), output_cpu)
+        self.assertClose(output_cuda.cpu(), output_python)
+        output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), True, False)
+        self.assertClose(output_cpu, output_python)
 
     @staticmethod
     def graph_conv_forward_backward(