From 7944d24d4872bdb01b821450840049e28d0ce12b Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 5 Aug 2020 06:58:53 -0700 Subject: [PATCH] gather_scatter on CPU Summary: CPU implementation of the graph convolution op. Reviewed By: nikhilaravi, gkioxari Differential Revision: D21384361 fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b --- .../csrc/gather_scatter/gather_scatter.cu | 4 +-- .../csrc/gather_scatter/gather_scatter.h | 17 +++++---- .../gather_scatter/gather_scatter_cpu.cpp | 35 +++++++++++++++++++ tests/test_graph_conv.py | 31 ++++++++++------ 4 files changed, 68 insertions(+), 19 deletions(-) create mode 100644 pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp diff --git a/pytorch3d/csrc/gather_scatter/gather_scatter.cu b/pytorch3d/csrc/gather_scatter/gather_scatter.cu index 0e744a115..bd0a97db8 100644 --- a/pytorch3d/csrc/gather_scatter/gather_scatter.cu +++ b/pytorch3d/csrc/gather_scatter/gather_scatter.cu @@ -44,8 +44,8 @@ __global__ void GatherScatterCudaKernel( } at::Tensor GatherScatterCuda( - const at::Tensor input, - const at::Tensor edges, + const at::Tensor& input, + const at::Tensor& edges, bool directed, bool backward) { // Check inputs are on the same device diff --git a/pytorch3d/csrc/gather_scatter/gather_scatter.h b/pytorch3d/csrc/gather_scatter/gather_scatter.h index 864e84ffd..20e5919ba 100644 --- a/pytorch3d/csrc/gather_scatter/gather_scatter.h +++ b/pytorch3d/csrc/gather_scatter/gather_scatter.h @@ -20,17 +20,22 @@ // Returns: // output: float32 Tensor of same shape as input. -// Cuda implementation. at::Tensor GatherScatterCuda( - const at::Tensor input, - const at::Tensor edges, + const at::Tensor& input, + const at::Tensor& edges, + bool directed, + bool backward); + +at::Tensor GatherScatterCpu( + const at::Tensor& input, + const at::Tensor& edges, bool directed, bool backward); // Exposed implementation. 
at::Tensor GatherScatter( - const at::Tensor input, - const at::Tensor edges, + const at::Tensor& input, + const at::Tensor& edges, bool directed, bool backward) { if (input.is_cuda() && edges.is_cuda()) { @@ -42,5 +47,5 @@ at::Tensor GatherScatter( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("Not implemented on the CPU"); + return GatherScatterCpu(input, edges, directed, backward); } diff --git a/pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp b/pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp new file mode 100644 index 000000000..ba24b8b11 --- /dev/null +++ b/pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include <torch/extension.h> + +at::Tensor GatherScatterCpu( + const at::Tensor& input, + const at::Tensor& edges, + bool directed, + bool backward) { + const auto num_vertices = input.size(0); + const auto input_feature_dim = input.size(1); + const auto num_edges = edges.size(0); + + auto output = at::zeros({num_vertices, input_feature_dim}, input.options()); + + auto input_a = input.accessor<float, 2>(); + auto edges_a = edges.accessor<int64_t, 2>(); + auto output_a = output.accessor<float, 2>(); + const int v0_idx = backward ? 1 : 0; + const int v1_idx = backward ? 0 : 1; + + for (int e = 0; e < num_edges; ++e) { + // Get indices of vertices which form the edge. 
+ const int64_t v0 = edges_a[e][v0_idx]; + const int64_t v1 = edges_a[e][v1_idx]; + + for (int d = 0; d < input_feature_dim; ++d) { + output_a[v0][d] += input_a[v1][d]; + if (!directed) { + output_a[v1][d] += input_a[v0][d]; + } + } + } + return output; +} diff --git a/tests/test_graph_conv.py b/tests/test_graph_conv.py index dd64d82d5..75ff4f909 100644 --- a/tests/test_graph_conv.py +++ b/tests/test_graph_conv.py @@ -101,17 +101,24 @@ def test_backward(self): mesh = ico_sphere() verts = mesh.verts_packed() edges = mesh.edges_packed() + verts_cpu = verts.clone() + edges_cpu = edges.clone() verts_cuda = verts.clone().to(device) edges_cuda = edges.clone().to(device) verts.requires_grad = True + verts_cpu.requires_grad = True verts_cuda.requires_grad = True neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False) + neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False) neighbor_sums = gather_scatter_python(verts, edges, False) - neighbor_sums_cuda.sum().backward() - neighbor_sums.sum().backward() + randoms = torch.rand_like(neighbor_sums) + (neighbor_sums_cuda * randoms.cuda()).sum().backward() + (neighbor_sums_cpu * randoms).sum().backward() + (neighbor_sums * randoms).sum().backward() - self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu()) + self.assertClose(verts.grad, verts_cuda.grad.cpu()) + self.assertClose(verts.grad, verts_cpu.grad) def test_repr(self): conv = GraphConv(32, 64, directed=True) @@ -141,22 +148,24 @@ def test_gather_scatter(self): w0 = nn.Linear(3, 1) input = w0(verts) - # output - output_cpu = gather_scatter_python(input, edges, False) + # undirected + output_python = gather_scatter_python(input, edges, False) output_cuda = _C.gather_scatter( input.to(device=device), edges.to(device=device), False, False ) - self.assertClose(output_cuda.cpu(), output_cpu) - with self.assertRaises(Exception) as err: - _C.gather_scatter(input.cpu(), edges.cpu(), False, False) - self.assertTrue("Not implemented on the CPU" in 
str(err.exception)) + self.assertClose(output_cuda.cpu(), output_python) + + output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), False, False) + self.assertClose(output_cpu, output_python) # directed - output_cpu = gather_scatter_python(input, edges, True) + output_python = gather_scatter_python(input, edges, True) output_cuda = _C.gather_scatter( input.to(device=device), edges.to(device=device), True, False ) - self.assertClose(output_cuda.cpu(), output_cpu) + self.assertClose(output_cuda.cpu(), output_python) + output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), True, False) + self.assertClose(output_cpu, output_python) @staticmethod def graph_conv_forward_backward(