Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GraphEdgeBlob #7

Merged
merged 3 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/table/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Publish the base dependency list for the table libraries as a global property.
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)

# Read the global property back into the TABLE_DEPS variable.
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
# graph_edge: built from graph_edge.cc with the distributed compile flags.
set_source_files_properties(graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_edge SRCS graph_edge.cc)
set_source_files_properties(weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
# NOTE(review): the next two lines declare the same target name; they look like
# the before/after sides of this diff hunk (the second adds DEPS graph_edge).
# Confirm only one of them is present in the real file.
cc_library(WeightedSampler SRCS weighted_sampler.cc)
cc_library(WeightedSampler SRCS weighted_sampler.cc DEPS graph_edge)
# graph_node: built from graph_node.cc and depends on WeightedSampler.
set_source_files_properties(graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_node SRCS graph_node.cc DEPS WeightedSampler)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
Expand All @@ -11,7 +13,7 @@ set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTR
set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} graph_node device_context string_helper simple_threadpool xxhash generator)
cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)

set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
Expand Down
18 changes: 10 additions & 8 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ GraphNode *GraphShard::add_node(uint64_t id, std::string feature) {
return bucket.back();
}

// Adds one outgoing edge to node `id`: obtains the node through
// add_node(id, "") and forwards the edge data to its add_edge().
// NOTE(review): two signature/body pairs share the single closing brace below;
// they look like the replaced (GraphEdge*) and replacement (dst_id, weight)
// sides of this diff hunk -- confirm only one is present in the real file.
void GraphShard::add_neighboor(uint64_t id, GraphEdge *edge) {
add_node(id, std::string(""))->add_edge(edge);
void GraphShard::add_neighboor(uint64_t id, uint64_t dst_id, float weight) {
add_node(id, std::string(""))->add_edge(dst_id, weight);
}

GraphNode *GraphShard::find_node(uint64_t id) {
Expand Down Expand Up @@ -147,6 +147,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
auto paths = paddle::string::split_string<std::string>(path, ";");
int count = 0;
std::string sample_type = "random";
bool is_weighted = false;

for (auto path : paths) {
std::ifstream file(path);
Expand All @@ -164,6 +165,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
is_weighted = true;
}

size_t src_shard_id = src_id % shard_num;
Expand All @@ -175,8 +177,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
}

size_t index = src_shard_id - shard_start;
GraphEdge *edge = new GraphEdge(dst_id, weight);
shards[index].add_neighboor(src_id, edge);
shards[index].add_node(src_id, std::string(""))->build_edges(is_weighted);
shards[index].add_neighboor(src_id, dst_id, weight);
}
}
VLOG(0) << "Load Finished Total Edge Count " << count;
Expand Down Expand Up @@ -287,17 +289,17 @@ int GraphTable::random_sample_neighboors(
actual_size = 0;
return 0;
}
std::vector<GraphEdge *> res = node->sample_k(sample_size);
std::vector<int> res = node->sample_k(sample_size);
actual_size =
res.size() * (GraphNode::id_size + GraphNode::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (auto &x : res) {
id = x->get_id();
weight = x->get_weight();
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, GraphNode::id_size);
offset += GraphNode::id_size;
memcpy(buffer_addr + offset, &weight, GraphNode::weight_size);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GraphShard {
}
GraphNode *add_node(uint64_t id, std::string feature);
GraphNode *find_node(uint64_t id);
void add_neighboor(uint64_t id, GraphEdge *edge);
void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
// std::unordered_map<uint64_t, std::list<GraphNode *>::iterator>
std::unordered_map<uint64_t, int> get_node_location() {
return node_location;
Expand Down
30 changes: 30 additions & 0 deletions paddle/fluid/distributed/table/graph_edge.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/table/graph_edge.h"
#include <cstring>
namespace paddle {
namespace distributed {

/// Appends one edge to this blob by recording the neighbor id.
///
/// @param id  Neighbor id; pushed onto the back of id_arr.
/// The second (weight) argument is accepted only to keep a uniform add_edge
/// signature and is ignored here: this class stores ids only.
///
/// No default argument is given on this definition: a default placed on an
/// out-of-class definition is visible only later in this translation unit, so
/// code that includes the header could never have used it.
void GraphEdgeBlob::add_edge(uint64_t id, float /*weight*/) {
  id_arr.push_back(id);
}

/// Appends one weighted edge to this blob.
///
/// @param id      Neighbor id; pushed onto the back of id_arr.
/// @param weight  Edge weight; pushed onto the back of weight_arr, so the two
///                vectors grow together on every call.
///
/// No default argument is given on this definition: a default placed on an
/// out-of-class definition is visible only later in this translation unit, so
/// code that includes the header could never have used it.
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight) {
  id_arr.push_back(id);
  weight_arr.push_back(weight);
}

}
}
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/table/graph_edge.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>
#include <cstdint>
namespace paddle {
namespace distributed {


/// Edge storage that keeps only neighbor ids, in insertion order.
/// get_weight() reports 1 for every index because no weights are stored.
class GraphEdgeBlob {
 public:
  GraphEdgeBlob() {}
  virtual ~GraphEdgeBlob() {}

  /// Number of ids currently stored.
  size_t size() const { return id_arr.size(); }

  /// Appends an edge; defined out of line.
  virtual void add_edge(uint64_t id, float weight);

  /// Id stored at position idx. No bounds check is performed.
  uint64_t get_id(int idx) const { return id_arr[idx]; }

  /// Weight of the edge at idx; always 1 in this class.
  // The signature (including the return-type const and the non-const
  // qualifier) is kept exactly as it was so that existing overriders in
  // derived classes still override it.
  virtual const float get_weight(int idx) { return 1; }

 protected:
  std::vector<uint64_t> id_arr;
};

class WeightedGraphEdgeBlob: public GraphEdgeBlob{
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
//virtual void add_edge(WeightedGraphEdge e);
virtual void add_edge(uint64_t id, float weight);
virtual const float get_weight(int idx) { return weight_arr[idx]; }
protected:
std::vector<float> weight_arr;
};

}
}
26 changes: 23 additions & 3 deletions paddle/fluid/distributed/table/graph_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,41 @@
#include <cstring>
namespace paddle {
namespace distributed {


// Releases the sampler and the edge blob held by this node and clears both
// pointers.
GraphNode::~GraphNode() {
  // delete on a null pointer is a no-op, so no explicit null checks are needed.
  delete sampler;
  sampler = nullptr;

  delete edges;
  edges = nullptr;
}

// Byte widths shared by all nodes: weight_size and id_size are the sizes of
// float and uint64_t, int_size is the size of int.
int GraphNode::weight_size = sizeof(float);
int GraphNode::id_size = sizeof(uint64_t);
int GraphNode::int_size = sizeof(int);
// Returns id_size + int_size, plus the length of the feature string when
// need_feature is true.
int GraphNode::get_size(bool need_feature) {
  int total = id_size + int_size;
  if (need_feature) {
    total += feature.size();
  }
  return total;
}
// Allocates the edge blob on first use: a WeightedGraphEdgeBlob when
// is_weighted is true, otherwise a plain GraphEdgeBlob. Once a blob exists,
// later calls leave it untouched regardless of is_weighted.
void GraphNode::build_edges(bool is_weighted) {
  if (edges != nullptr) {
    return;
  }
  if (is_weighted) {
    edges = new WeightedGraphEdgeBlob();
  } else {
    edges = new GraphEdgeBlob();
  }
}
// Allocates the neighbor sampler selected by sample_type ("random" ->
// RandomSampler, "weighted" -> WeightedSampler) and builds it over edges.
// NOTE(review): any other sample_type allocates nothing here, yet
// sampler->build is still called below -- confirm callers only ever pass
// these two strings.
// NOTE(review): a sampler that already exists is overwritten without being
// deleted first.
void GraphNode::build_sampler(std::string sample_type) {
if (sample_type == "random"){
sampler = new RandomSampler();
} else if (sample_type == "weighted"){
sampler = new WeightedSampler();
}
//GraphEdge** arr = edges.data();
//sampler->build((WeightedObject**)arr, 0, edges.size());
// NOTE(review): the two build calls below look like the replaced and the
// replacement line of this diff hunk -- confirm only one remains in the file.
sampler->build((std::vector<WeightedObject*>*)&edges);
sampler->build(edges);
}
void GraphNode::to_buffer(char* buffer, bool need_feature) {
int size = get_size(need_feature);
Expand Down
36 changes: 10 additions & 26 deletions paddle/fluid/distributed/table/graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,33 @@
#include "paddle/fluid/distributed/table/weighted_sampler.h"
namespace paddle {
namespace distributed {
// enum GraphNodeType { user = 0, item = 1, query = 2, unknown = 3 };
// One edge to a neighbor node: the neighbor id plus a float weight.
class GraphEdge : public WeightedObject {
 public:
  // Zero-initialise both fields so a default-constructed edge never exposes
  // indeterminate values through get_id()/get_weight().
  GraphEdge() : id(0), weight(0.0f) {}
  GraphEdge(uint64_t id, float weight) : id(id), weight(weight) {}
  // The getters stay non-const and unannotated: WeightedObject is declared in
  // another header, so whether they override base-class virtuals (and with
  // which qualifiers) cannot be checked from here.
  uint64_t get_id() { return id; }
  float get_weight() { return weight; }
  uint64_t id;
  float weight;
};

// Graph vertex holding an id, a feature string, a Sampler pointer and its
// edge storage.
// NOTE(review): this span is a diff hunk without +/- markers, so replaced and
// replacement lines sit side by side (two default constructors, two
// initializer lists, two destructors, two sample_k bodies, two `edges`
// members). Confirm which side is current before relying on any line here.
class GraphNode {
public:
GraphNode() { sampler = NULL; }
GraphNode(): sampler(nullptr), edges(nullptr) { }
GraphNode(uint64_t id, std::string feature)
: id(id), feature(feature), sampler(NULL) {}
virtual ~GraphNode() {}
std::vector<GraphEdge *> get_graph_edge() { return edges; }
: id(id), feature(feature), sampler(nullptr), edges(nullptr) {}
virtual ~GraphNode();
// Sizes exposed as statics; their definitions live outside this class body.
static int id_size, int_size, weight_size;
uint64_t get_id() { return id; }
void set_id(uint64_t id) { this->id = id; }
void set_feature(std::string feature) { this->feature = feature; }
std::string get_feature() { return feature; }
virtual int get_size(bool need_feature);
virtual void build_edges(bool is_weighted);
virtual void build_sampler(std::string sample_type);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual void add_edge(GraphEdge *edge) { edges.push_back(edge); }
// Pointer-returning variant: yields an empty vector when sampler is NULL.
std::vector<GraphEdge *> sample_k(int k) {
std::vector<GraphEdge *> v;
if (sampler != NULL) {
auto res = sampler->sample_k(k);
for (auto x : res) {
v.push_back((GraphEdge *)x);
}
}
return v;
}
// Forwards to the edge blob. NOTE(review): edges is dereferenced without a
// null check although the nullptr-initialising constructors leave it null;
// presumably build_edges() must run first -- confirm against callers.
virtual void add_edge(uint64_t id, float weight) { edges->add_edge(id, weight); }
// Int-returning variant: forwards to sampler->sample_k(k). NOTE(review):
// sampler is dereferenced unguarded here, unlike the variant above that
// checks for NULL; the ints are presumably indices for the two accessors
// below -- confirm.
std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
uint64_t get_neighbor_id(int idx){return edges->get_id(idx);}
float get_neighbor_weight(int idx){return edges->get_weight(idx);}

protected:
uint64_t id;
std::string feature;
Sampler *sampler;
std::vector<GraphEdge *> edges;
GraphEdgeBlob * edges;
};
}
}
Loading