Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Khop Graph Sampler API #39146

Merged
merged 47 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
2f5f235
add the test case for the UVA
wawltor Dec 14, 2021
212e742
Merge branch 'cuda_uva' of https://github.com/wawltor/Paddle into mul…
DesmonDay Dec 29, 2021
adb1886
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
DesmonDay Dec 29, 2021
83f7d0c
add the context load for the uva
wawltor Dec 31, 2021
31101e4
Merge branch 'cuda_uva' of https://github.com/wawltor/Paddle into mul…
DesmonDay Jan 6, 2022
93a9bd2
Add graph_sample kernel
DesmonDay Jan 10, 2022
4c8e430
Add graph_sample commit
DesmonDay Jan 13, 2022
0730d2a
add new commit for graph_sample
DesmonDay Jan 14, 2022
ce915f0
add unsigned long long int
DesmonDay Jan 14, 2022
28864b2
delete some remarks
DesmonDay Jan 17, 2022
7cb088a
add cpu version
DesmonDay Jan 18, 2022
f3fb01c
add cuda eids
DesmonDay Jan 18, 2022
c1f2284
add cpu eids
DesmonDay Jan 18, 2022
bc22884
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
DesmonDay Jan 19, 2022
5e60e89
delete _uva
DesmonDay Jan 19, 2022
6a442a8
optimize speed: emplace_back, last_layer
DesmonDay Jan 20, 2022
fc8665c
add to_uva_tensor
DesmonDay Jan 21, 2022
eebe1e3
add cpu return_eids choice
DesmonDay Jan 21, 2022
b69cc87
add gpu return_eids choice
DesmonDay Jan 21, 2022
ca41d39
add cpu reindex_nodes
DesmonDay Jan 22, 2022
8ab2f39
add gpu reindex_nodes
DesmonDay Jan 22, 2022
b800396
rename op and add OMP for cpu
DesmonDay Jan 22, 2022
1e5ff06
add incubate api
DesmonDay Jan 22, 2022
1b432ef
fix the compile problem for the PADDLE_ENFORE and different device
wawltor Jan 23, 2022
08cc35e
fix the rcom and windows compile problem
wawltor Jan 23, 2022
26b16a3
add unittest for graph_sample_neighbors
DesmonDay Jan 23, 2022
1cdb372
fix cpu unittest and unique problem
DesmonDay Jan 23, 2022
5159657
Merge branch 'debug_graph_sample' of https://github.com/DesmonDay/Pad…
DesmonDay Jan 23, 2022
35b324f
fix uva unittest, fix cuda unique problem
DesmonDay Jan 23, 2022
08b26a1
fix the windows compile problem
wawltor Jan 23, 2022
aef80d6
fix the windows rand_r compile problem
wawltor Jan 24, 2022
b60b73a
add correct unittest, add src_eids dispensable
DesmonDay Jan 24, 2022
8cf118e
Merge branch 'debug_graph_sample' of https://github.com/DesmonDay/Pad…
DesmonDay Jan 24, 2022
4df0a8d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
DesmonDay Jan 24, 2022
2ce02d8
delete black
DesmonDay Jan 24, 2022
d041655
combine uva unittest
DesmonDay Jan 24, 2022
6be601d
mv Sample_index to Sample_Index; check input shape; fix random sample…
DesmonDay Jan 24, 2022
4e25f93
delete memset & cudaMemset
DesmonDay Jan 24, 2022
ee9e11b
fix according to PR comments
DesmonDay Jan 24, 2022
4ebf80b
fix rocm ci
DesmonDay Jan 24, 2022
cbe8c8b
modify function names according to the specification
DesmonDay Jan 24, 2022
cb65a63
fix windows_openblas ci
DesmonDay Jan 24, 2022
33705a3
refine annotations, fix windows unittest, add default value for uva d…
DesmonDay Jan 24, 2022
e276b29
fix rocm ci
DesmonDay Jan 24, 2022
5a665fc
rename graph_sample_neighbors as graph_khop_sampler, add incubate api…
DesmonDay Jan 26, 2022
6ab1421
add data type
DesmonDay Jan 26, 2022
d3e36c8
fix conflict
DesmonDay Jan 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions paddle/fluid/operators/graph_sample_neighbors_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
DesmonDay marked this conversation as resolved.
Show resolved Hide resolved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/graph_sample_neighbors_op.h"

namespace paddle {
namespace operators {

// Operator definition for graph_sample_neighbors. It checks that the op's
// inputs/outputs are present, validates the `sample_sizes` attribute, declares
// the output shapes, and picks the kernel from the data type of "Src".
class GraphSampleNeighborsOP : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies required inputs/outputs and attributes, then sets the output
  // dims. Raises InvalidArgument when `sample_sizes` is empty.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Src"), "Input", "Src",
                   "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasInput("Src_Eids"), "Input", "Src_Eids",
                   "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasInput("Dst_Count"), "Input", "Dst_Count",
                   "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasOutput("Out_Src"), "Output", "Out_Src",
                   "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasOutput("Out_Dst"), "Output", "Out_Dst",
                   "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasOutput("Sample_index"), "Output", "Sample_index",
                   "GraphSampleNeighbors");
    OP_INOUT_CHECK(ctx->HasOutput("Reindex_X"), "Output", "Reindex_X",
                   "GraphSampleNeighbors");
    // TODO: decide whether all inputs and outputs should be restricted to 1-D
    // vectors, or to 2-D tensors whose second dimension is 1.

    const std::vector<int>& sample_sizes =
        ctx->Attrs().Get<std::vector<int>>("sample_sizes");
    PADDLE_ENFORCE_EQ(
        !sample_sizes.empty(), true,
        platform::errors::InvalidArgument(
            "The parameter 'sample_sizes' in GraphSampleNeighbors must be "
            "set. But received 'sample_sizes' is empty."));

    // "Out_Eids" is only required (and only given a shape) when the caller
    // asks for the sampled edge ids.
    const bool& return_eids = ctx->Attrs().Get<bool>("return_eids");
    if (return_eids) {
      OP_INOUT_CHECK(ctx->HasOutput("Out_Eids"), "Output", "Out_Eids",
                     "GraphSampleNeighbors");
      ctx->SetOutputDim("Out_Eids", {-1});
    }

    // The number of sampled edges is not known at shape-inference time, so
    // the leading dimension is left dynamic (-1).
    // NOTE(review): it is still an open question whether Out_Src / Out_Dst
    // must be 2-D, and whether this needs to match the current PGL version.
    ctx->SetOutputDim("Out_Src", {-1, 1});
    ctx->SetOutputDim("Out_Dst", {-1, 1});
    ctx->SetOutputDim("Sample_index", {-1});

    // "Reindex_X" has the same dims as the input "X".
    auto dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Reindex_X", dims);
  }

 protected:
  // Selects the kernel by the data type of input "Src" and the device context
  // of the current execution.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Src"),
        ctx.device_context());
  }
};

// Declares the inputs, outputs, attributes and documentation string of the
// graph_sample_neighbors operator.
class GraphSampleNeighborsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Src", "The src index tensor after sorted by dst.");
    AddInput("Src_Eids", "The eids of the input graph edges.");
    AddInput("Dst_Count",
             "The indegree cumsum of dst index, starts from 0, end with number "
             "of edges");
    AddInput("X", "The input center nodes index tensor.");
    AddOutput("Out_Src",
              "The output src edges tensor after sampling and reindex.");
    AddOutput("Out_Dst",
              "The output dst edges tensor after sampling and reindex.");
    AddOutput("Sample_index",
              "The original index of the sampling nodes and center nodes.");
    AddOutput("Reindex_X", "The reindex node id of the input nodes.");
    // Declared as an intermediate output; InferShape only requires it when
    // `return_eids` is true.
    AddOutput("Out_Eids", "The eids of the sample edges.").AsIntermediate();
    // Defaults to an empty list, which InferShape rejects, so callers must
    // set it explicitly.
    AddAttr<std::vector<int>>(
        "sample_sizes", "The sample sizes of graph sample neighbors method.")
        .SetDefault({});
    AddAttr<bool>("return_eids",
                  "Whether to return the eids of the sample edges.")
        .SetDefault(false);
    AddComment(R"DOC(
Graph Learning Sampling Neighbors operator, for graphsage sampling method.

)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

// Register the graph_sample_neighbors operator together with its proto maker.
REGISTER_OPERATOR(graph_sample_neighbors, ops::GraphSampleNeighborsOP,
                  ops::GraphSampleNeighborsOpMaker);
// Register the CPU kernels for the int and int64_t instantiations.
REGISTER_OP_CPU_KERNEL(graph_sample_neighbors,
                       ops::GraphSampleNeighborsOpKernel<CPU, int>,
                       ops::GraphSampleNeighborsOpKernel<CPU, int64_t>);
Loading