Skip to content

Commit

Permalink
Merge branch 'master' into labor_uvm_optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying committed Jul 3, 2023
2 parents 009be59 + fe83348 commit 22b5d74
Show file tree
Hide file tree
Showing 39 changed files with 642 additions and 496 deletions.
4 changes: 3 additions & 1 deletion graphbolt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)

add_library(${LIB_GRAPHBOLT_NAME} SHARED ${BOLT_SRC} ${BOLT_HEADERS})
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${BOLT_DIR}
${BOLT_HEADERS})
${BOLT_HEADERS}
"../third_party/dmlc-core/include"
"../third_party/pcg/include")
target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}")

# The Torch CMake configuration only sets up the path for the MKL library when
Expand Down
35 changes: 20 additions & 15 deletions graphbolt/src/csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,26 @@ torch::Tensor PickByEtype(
fanouts.size(), torch::tensor({}, options));
int64_t etype_begin = offset;
int64_t etype_end = offset;
while (etype_end < offset + num_neighbors) {
int64_t etype = type_per_edge[etype_end].item<int64_t>();
int64_t fanout = fanouts[etype];
while (etype_end < offset + num_neighbors &&
type_per_edge[etype_end].item<int64_t>() == etype) {
etype_end++;
}
// Do sampling for one etype.
if (fanout != 0) {
picked_neighbors[etype] = Pick(
etype_begin, etype_end - etype_begin, fanout, replace, options,
probs_or_mask);
}
etype_begin = etype_end;
}
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "PickByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
const auto end = offset + num_neighbors;
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
int64_t fanout = fanouts[etype];
auto etype_end_it = std::upper_bound(
type_per_edge_data + etype_begin, type_per_edge_data + end,
etype);
etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype.
if (fanout != 0) {
picked_neighbors[etype] = Pick(
etype_begin, etype_end - etype_begin, fanout, replace, options,
probs_or_mask);
}
etype_begin = etype_end;
}
}));

return torch::cat(picked_neighbors, 0);
}
Expand Down
77 changes: 77 additions & 0 deletions graphbolt/src/random.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@

/**
* Copyright (c) 2023 by Contributors
*
* @file random.h
* @brief Random Engine class.
*/
#ifndef GRAPHBOLT_RANDOM_H_
#define GRAPHBOLT_RANDOM_H_

#include <dmlc/thread_local.h>

#include <atomic>
#include <cstdint>
#include <mutex>
#include <pcg_random.hpp>
#include <random>
#include <thread>

namespace graphbolt {

/**
 * @brief Get a small integer ID that is unique to the calling thread.
 *
 * IDs are handed out in first-call order starting from 0, and a thread keeps
 * the same ID on every later call.
 *
 * This is deliberately a plain `inline` function instead of a member of an
 * unnamed namespace. Inside a header, an unnamed namespace gives every
 * translation unit that includes it a private copy of the function and of its
 * static counter, so the IDs would only be unique per translation unit. An
 * `inline` function shares one counter between all including translation
 * units.
 *
 * @return The ID of the calling thread.
 */
inline uint32_t GetThreadId() {
  // Next ID that has not been handed out yet; shared by all threads.
  static std::atomic<uint32_t> next_id{0};
  // Initialized exactly once per thread, on that thread's first call. Only
  // the atomicity of the increment matters, so relaxed ordering is enough.
  thread_local const uint32_t id =
      next_id.fetch_add(1, std::memory_order_relaxed);
  return id;
}

/**
* @brief Thread-local Random Number Generator class.
*/
class RandomEngine {
public:
/** @brief Constructor with default seed. */
RandomEngine() {
std::random_device rd;
SetSeed(rd());
}

/** @brief Constructor with given seed. */
explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) {
SetSeed(seed, stream);
}

/** @brief Get the thread-local random number generator instance. */
static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}

/** @brief Set the seed. */
void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) {
rng_.seed(seed, stream);
}

/**
* @brief Generate a uniform random integer in [low, high).
*/
template <typename T>
T RandInt(T lower, T upper) {
std::uniform_int_distribution<T> dist(lower, upper - 1);
return dist(rng_);
}

private:
pcg32 rng_;
};
} // namespace graphbolt

#endif // GRAPHBOLT_RANDOM_H_
2 changes: 1 addition & 1 deletion python/dgl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
partition_graph,
)
from .rpc import *
from .rpc_client import connect_to_server, shutdown_servers
from .rpc_client import connect_to_server
from .rpc_server import start_server
from .server_state import ServerState
1 change: 0 additions & 1 deletion python/dgl/distributed/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
MAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024

SERVER_EXIT = "server_exit"
SERVER_KEEP_ALIVE = "server_keep_alive"

DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
2 changes: 0 additions & 2 deletions python/dgl/distributed/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,13 @@ def initialize(
formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
formats = [f.strip() for f in formats]
rpc.reset()
keep_alive = bool(int(os.environ.get("DGL_KEEP_ALIVE", 0)))
serv = DistGraphServer(
int(os.environ.get("DGL_SERVER_ID")),
os.environ.get("DGL_IP_CONFIG"),
int(os.environ.get("DGL_NUM_SERVER")),
int(os.environ.get("DGL_NUM_CLIENT")),
os.environ.get("DGL_CONF_PATH"),
graph_format=formats,
keep_alive=keep_alive,
)
serv.start()
sys.exit()
Expand Down
5 changes: 0 additions & 5 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,6 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
keep_alive : bool
Whether to keep server alive when clients exit
"""

def __init__(
Expand All @@ -343,7 +341,6 @@ def __init__(
part_config,
disable_shared_mem=False,
graph_format=("csc", "coo"),
keep_alive=False,
):
super(DistGraphServer, self).__init__(
server_id=server_id,
Expand All @@ -353,7 +350,6 @@ def __init__(
)
self.ip_config = ip_config
self.num_servers = num_servers
self.keep_alive = keep_alive
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll be initialized afterwards.
Expand Down Expand Up @@ -457,7 +453,6 @@ def start(self):
kv_store=self,
local_g=self.client_g,
partition_book=self.gpb,
keep_alive=self.keep_alive,
)
print(
"start graph service on server {} for part {}".format(
Expand Down
3 changes: 0 additions & 3 deletions python/dgl/distributed/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,6 @@ def process_request(self, server_state):
meta = {}
kv_store = server_state.kv_store
for name, data in kv_store.data_store.items():
if server_state.keep_alive:
if name not in kv_store.orig_data:
continue
meta[name] = (
F.shape(data),
F.reverse_data_type_dict[F.dtype(data)],
Expand Down
10 changes: 3 additions & 7 deletions python/dgl/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..base import DGLError
from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
from .constants import SERVER_EXIT

__all__ = [
"set_rank",
Expand Down Expand Up @@ -172,7 +172,7 @@ def finalize_receiver():
_CAPI_DGLRPCFinalizeReceiver()


def wait_for_senders(ip_addr, port, num_senders, blocking=True):
def wait_for_senders(ip_addr, port, num_senders):
"""Wait all of the senders' connections.
This api will be blocked until all the senders connect to the receiver.
Expand All @@ -185,10 +185,8 @@ def wait_for_senders(ip_addr, port, num_senders, blocking=True):
receiver's port
num_senders : int
total number of senders
blocking : bool
whether to wait blockingly
"""
_CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders), blocking)
_CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders))


def connect_receiver(ip_addr, port, recv_id, group_id=-1):
Expand Down Expand Up @@ -1258,8 +1256,6 @@ def __setstate__(self, state):

def process_request(self, server_state):
assert self.client_id == 0
if server_state.keep_alive and not self.force_shutdown_server:
return SERVER_KEEP_ALIVE
finalize_server()
return SERVER_EXIT

Expand Down
43 changes: 1 addition & 42 deletions python/dgl/distributed/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def connect_to_server(
for server_id in range(num_servers):
rpc.send_request(server_id, register_req)
# wait server connect back
rpc.wait_for_senders(client_ip, client_port, num_servers, blocking=True)
rpc.wait_for_senders(client_ip, client_port, num_servers)
print(
"Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port)
)
Expand All @@ -226,44 +226,3 @@ def connect_to_server(

atexit.register(exit_client)
set_initialized(True)


def shutdown_servers(ip_config, num_servers):
"""Issue commands to remote servers to shut them down.
This function is required to be called manually only when we
have booted servers which keep alive even clients exit. In
order to shut down server elegantly, we utilize existing
client logic/code to boot a special client which does nothing
but send shut down request to servers. Once such request is
received, servers will exit from endless wait loop, release
occupied resources and end its process. Please call this function
with same arguments used in `dgl.distributed.connect_to_server`.
Parameters
----------
ip_config : str
Path of server IP configuration file.
num_servers : int
server count on each machine.
Raises
------
ConnectionError : If anything wrong with the connection.
"""
# Register the shutdown service so that ShutDownRequest can be sent.
# NOTE(review): the trailing None presumably means "no response type" -- confirm
# against rpc.register_service.
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers)
# From here on num_servers is the number of entries in the namebook, no longer
# the per-machine count that was passed in.
num_servers = len(server_namebook)
rpc.create_sender(MAX_QUEUE_SIZE)
# Get connected with all server nodes
for server_id, addr in server_namebook.items():
# NOTE(review): addr[1] / addr[2] are used as IP and port -- the full layout of
# a namebook entry is not visible here, verify against rpc.read_ip_config.
server_ip = addr[1]
server_port = addr[2]
# Retry once per second until connect_receiver reports success.
while not rpc.connect_receiver(server_ip, server_port, server_id):
time.sleep(1)
# send ShutDownRequest to all servers
# NOTE(review): the arguments look like (client_id=0, force_shutdown_server=True)
# -- confirm against rpc.ShutDownRequest.
req = rpc.ShutDownRequest(0, True)
for server_id in range(num_servers):
rpc.send_request(server_id, req)
rpc.finalize_sender()
12 changes: 2 additions & 10 deletions python/dgl/distributed/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..base import DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT


def start_server(
Expand Down Expand Up @@ -52,8 +52,6 @@ def start_server(
assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size
)
if server_state.keep_alive:
assert False, "Long live server is not supported any more."
# Register signal handler.
rpc.register_sig_handler()
# Register some basic services
Expand Down Expand Up @@ -85,7 +83,7 @@ def start_server(
print(
"Server is waiting for connections on [{}:{}]...".format(ip_addr, port)
)
rpc.wait_for_senders(ip_addr, port, num_clients, blocking=True)
rpc.wait_for_senders(ip_addr, port, num_clients)
rpc.set_num_client(num_clients)
recv_clients = {}
while True:
Expand Down Expand Up @@ -146,12 +144,6 @@ def start_server(
if res == SERVER_EXIT:
print("Server is exiting...")
return
elif res == SERVER_KEEP_ALIVE:
print(
"Server keeps alive while client group~{} is exiting...".format(
group_id
)
)
else:
raise DGLError("Unexpected response: {}".format(res))
else:
Expand Down
10 changes: 1 addition & 9 deletions python/dgl/distributed/server_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,12 @@ class ServerState:
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
keep_alive : bool
whether to keep alive which supports any number of client groups connect
"""

def __init__(self, kv_store, local_g, partition_book, keep_alive=False):
def __init__(self, kv_store, local_g, partition_book):
self._kv_store = kv_store
self._graph = local_g
self.partition_book = partition_book
self._keep_alive = keep_alive
self._roles = {}

@property
Expand All @@ -72,10 +69,5 @@ def graph(self):
def graph(self, graph):
self._graph = graph

@property
def keep_alive(self):
"""Flag of whether keep alive"""
return self._keep_alive


_init_api("dgl.distributed.server_state")
4 changes: 4 additions & 0 deletions python/dgl/graphbolt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from .itemset import *
from .minibatch_sampler import *
from .feature_store import *
from .feature_fetcher import *
from .copy_to import *
from .dataset import *
from .subgraph_sampler import *


def load_graphbolt():
Expand Down
Loading

0 comments on commit 22b5d74

Please sign in to comment.