Merge branch 'master' into structural-shifts
gvbazhenov committed Mar 8, 2023
2 parents 2fa397e + fe3d29a commit c7ddd2b
Showing 6 changed files with 388 additions and 39 deletions.
6 changes: 5 additions & 1 deletion python/dgl/dataloading/base.py
@@ -280,7 +280,11 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
for k, v in reverse_etype_map.items()
}
exclude_eids.update(
{reverse_etype_map[k]: v for k, v in exclude_eids.items()}
{
reverse_etype_map[k]: v
for k, v in exclude_eids.items()
if k in reverse_etype_map
}
)
return exclude_eids
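
The added `if k in reverse_etype_map` guard matters when only some edge types
have a declared reverse: the old comprehension looked up every excluded edge
type in reverse_etype_map and raised a KeyError for types without one. A
minimal sketch of the difference (the edge type names and ID lists are
simplified, hypothetical stand-ins for the canonical etypes used above):

    exclude_eids = {"AB": [0], "BC": [0]}   # "BC" has no declared reverse
    reverse_etype_map = {"AB": "BA"}

    # old comprehension: reverse_etype_map["BC"] raises KeyError
    # new comprehension: edge types without a reverse are simply skipped
    extra = {
        reverse_etype_map[k]: v
        for k, v in exclude_eids.items()
        if k in reverse_etype_map
    }
    assert extra == {"BA": [0]}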

136 changes: 107 additions & 29 deletions python/dgl/nn/pytorch/sparse_emb.py
@@ -348,6 +348,41 @@ def all_set_embedding(self, values):
if th.distributed.is_initialized():
th.distributed.barrier()

def _all_get_tensor(self, shared_name, tensor, shape):
"""A helper function to get model-parallel tensors.
This method must be called, and is only needed, in multi-GPU DDP training.
For now, it's only used in ``all_get_embedding`` and
``_all_get_optm_state``.
"""
# create a shared memory tensor
if self._rank == 0:
# root process creates shared memory
val = create_shared_mem_array(
shared_name,
shape,
tensor.dtype,
)
self._store.set(shared_name, shared_name)
else:
self._store.wait([shared_name])
val = get_shared_mem_array(
shared_name,
shape,
tensor.dtype,
)
# need to map indices and slice into existing tensor
idxs = self._partition.map_to_global(
F.arange(0, tensor.shape[0], ctx=F.context(tensor)),
self._rank,
).to(val.device)
val[idxs] = tensor.to(val.device)

self._store.delete_key(shared_name)
# wait for all processes to finish
th.distributed.barrier()
return val
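
This helper factors out a pattern used twice below: rank 0 creates a named
shared-memory buffer with the full global shape, the other ranks wait on the
key-value store and attach to it, and every rank scatters its local partition
into the global rows it owns before a final barrier. A rough standalone sketch
of that pattern, using only the Python standard library in place of DGL's
create_shared_mem_array/get_shared_mem_array and store (the function name, its
arguments, and the explicit barrier object are hypothetical):

    import numpy as np
    from multiprocessing import shared_memory

    def gather_into_shared(name, local_rows, global_indices, global_shape,
                           rank, barrier):
        # illustrative only: rank 0 creates the segment, the others attach
        nbytes = int(np.prod(global_shape)) * np.dtype(np.float32).itemsize
        if rank == 0:
            shm = shared_memory.SharedMemory(name=name, create=True, size=nbytes)
            barrier.wait()                 # segment now exists; let others attach
        else:
            barrier.wait()                 # wait until rank 0 has created it
            shm = shared_memory.SharedMemory(name=name)
        full = np.ndarray(global_shape, dtype=np.float32, buffer=shm.buf)
        full[global_indices] = local_rows  # scatter this rank's partition
        barrier.wait()                     # every rank has written its rows
        return full, shm                   # keep shm alive while full is in use

In the method above, the key-value store handles the create/attach handshake
(set/wait/delete_key) and th.distributed.barrier() provides the final
synchronization.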

def all_get_embedding(self):
"""Return a copy of the embedding stored in CPU memory. If this is a
multi-processing instance, the tensor will be returned in shared
@@ -367,35 +402,78 @@ def all_get_embedding(self):
# non-multiprocessing
return self._tensor.to(th.device("cpu"))
else:
# create a shared memory tensor
shared_name = self._name + "_gather"
if self._rank == 0:
# root process creates shared memory
emb = create_shared_mem_array(
shared_name,
(self._num_embeddings, self._embedding_dim),
self._tensor.dtype,
)
self._store.set(shared_name, shared_name)
else:
self._store.wait([shared_name])
emb = get_shared_mem_array(
shared_name,
(self._num_embeddings, self._embedding_dim),
self._tensor.dtype,
)
# need to map indices and slice into existing tensor
idxs = self._partition.map_to_global(
F.arange(
0, self._tensor.shape[0], ctx=F.context(self._tensor)
),
self._rank,
).to(emb.device)
emb[idxs] = self._tensor.to(emb.device)

# wait for all processes to finish
th.distributed.barrier()
return emb
return self._all_get_tensor(
f"{self._name}_gather",
self._tensor,
(self._num_embeddings, self._embedding_dim),
)
else:
# already stored in CPU memory
return self._tensor

def _all_get_optm_state(self):
"""Return a copy of the whole optimizer states stored in CPU memory.
If this is a multi-processing instance, the states will be returned in
shared memory. If the embedding is currently stored on multiple GPUs,
all processes must call this method in the same order.
NOTE: This method must be called by all processes sharing the
embedding, or it may result in a deadlock.
Returns
-------
tuple of torch.Tensor
The optimizer states stored in CPU memory.
"""
if self._partition:
if self._world_size == 0:
# non-multiprocessing
return tuple(
state.to(th.device("cpu")) for state in self._optm_state
)
else:
return tuple(
self._all_get_tensor(
f"state_gather_{self._name}_{i}",
state,
(self._num_embeddings, *state.shape[1:]),
)
for i, state in enumerate(self._optm_state)
)
else:
# already stored in CPU memory
return self._optm_state

def _all_set_optm_state(self, states):
"""Set the optimizer states of the embedding. This method must be
called by all processes sharing the embedding with identical
:attr:`states`.
NOTE: This method must be called by all processes sharing the
embedding, or it may result in a deadlock.
Parameters
----------
states : tuple of torch.Tensor
The global states to pull values from.
"""
if self._partition:
idxs = F.copy_to(
self._partition.get_local_indices(
max(self._rank, 0), ctx=F.context(self._tensor)
),
F.context(states[0]),
)
for state, new_state in zip(self._optm_state, states):
state[:] = F.copy_to(
F.gather_row(new_state, idxs), ctx=F.context(self._tensor)
)[:]
else:
# stored in CPU memory
if self._rank <= 0:
for state, new_state in zip(self._optm_state, states):
state[:] = F.copy_to(
new_state, ctx=F.context(self._tensor)
)[:]
if th.distributed.is_initialized():
th.distributed.barrier()
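
Together with all_get_embedding and all_set_embedding, these two helpers give
a full-copy round trip for checkpointing. A hedged usage sketch (the
NodeEmbedding constructor arguments and the file name are assumptions; the
optimizer must already have attached its states via set_optm_state, and every
rank sharing the embedding must make the same calls in the same order):

    import torch as th
    import dgl

    emb = dgl.nn.NodeEmbedding(1000, 16, name="node_emb")
    # ... create a sparse optimizer over [emb] and train ...

    full_weight = emb.all_get_embedding()    # full (1000, 16) tensor
    optm_states = emb._all_get_optm_state()  # tuple of full-size state tensors

    rank = th.distributed.get_rank() if th.distributed.is_initialized() else 0
    if rank == 0:                            # only one process writes the file
        th.save({"weight": full_weight, "optm": optm_states}, "emb_ckpt.pt")

    # later, after re-creating emb and its optimizer with the same shapes
    ckpt = th.load("emb_ckpt.pt")
    emb.all_set_embedding(ckpt["weight"])    # collective call on every rank
    emb._all_set_optm_state(ckpt["optm"])    # collective call on every rank
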
100 changes: 97 additions & 3 deletions python/dgl/optim/pytorch/sparse_optim.py
@@ -95,14 +95,14 @@ def step(self):
self._comm_setup()
else:
self._shared_setup()
self.setup(self._params)
self._first_step = False

if self._comm:
self._comm_step()
else:
self._shared_step()

@abstractmethod
def setup(self, params):
"""This is function where subclasses can perform any setup they need
to. It will be called during the first step, and communicators or
@@ -452,6 +452,59 @@ def zero_grad(self):
"""clean grad cache"""
self._clean_grad = True

def state_dict(self, **kwargs): # pylint: disable=unused-argument
"""Return a copy of the whole optimizer states stored in CPU memory.
If this is a multi-processing instance, the states will be returned in
shared memory. If the underlying embedding is currently stored on
multiple GPUs, all processes must call this method in the same order.
NOTE: This method must be called by all processes sharing the
underlying embedding, or it may result in a deadlock.
Returns
-------
dictionary of optimizer states
The optimizer states stored in CPU memory.
"""
return {
"state": {
emb.name: emb._all_get_optm_state() for emb in self._params
},
"param_groups": self.param_groups,
}

def load_state_dict(
self, state_dict, **kwargs
): # pylint: disable=unused-argument
"""Load the optimizer states. This method must be called by all
processes sharing the underlying embedding with identical
:attr:`state_dict`.
NOTE: This method must be called by all processes sharing the
underlying embedding, or it may result in a deadlock.
Parameters
----------
state_dict : dictionary of optimizer states
The global states to pull values from.
"""
for emb in self._params:
emb._all_set_optm_state(state_dict["state"][emb.name])
self._set_param_groups(state_dict["param_groups"])

@property
@abstractmethod
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Unlike its torch counterpart, the returned 'param_groups' does not contain
the parameters, because gathering the whole embedding is very expensive.
It contains the other attributes, e.g., lr and eps, for debugging.
"""

@abstractmethod
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""


class SparseAdagrad(SparseGradOptimizer):
r"""Node embedding optimizer using the Adagrad algorithm.
Expand Down Expand Up @@ -496,6 +549,9 @@ def __init__(self, params, lr, eps=1e-10):
super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps

# setup tensors for optimizer states
self.setup(self._params)

def setup(self, params):
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
@@ -532,7 +588,7 @@ def setup(self, params):
dtype=th.float32,
device=emb.weight.device,
).zero_()
emb.set_optm_state(state)
emb.set_optm_state((state,))

def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
@@ -562,7 +618,7 @@ def update(self, idx, grad, emb):
grad_values = grad_values / cnt.unsqueeze(1)

grad_sum = grad_values * grad_values
state = emb.optm_state
(state,) = emb.optm_state
state_dev = state.device
state_idx = grad_indices.to(state_dev)
grad_state = state[state_idx].to(grad.device)
@@ -573,6 +629,20 @@ def update(self, idx, grad, emb):
tmp = clr * grad_values / std_values
emb.weight[state_idx] -= tmp.to(state_dev)

@property
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Unlike its torch counterpart, the returned 'param_groups' does not contain
the parameters, because gathering the whole embedding is very expensive.
It contains the other attributes, e.g., lr and eps, for debugging.
"""
return [{"lr": self._lr, "eps": self._eps}]

def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
self._lr = groups[0]["lr"]
self._eps = groups[0]["eps"]
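
Continuing the sketch above, the emulated param_groups can be inspected
directly for debugging; for SparseAdagrad it holds exactly the lr and eps
values that round-trip through state_dict()/load_state_dict():

    opt = dgl.optim.SparseAdagrad([emb], lr=0.05, eps=1e-10)
    print(opt.param_groups)                  # [{'lr': 0.05, 'eps': 1e-10}]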


class SparseAdam(SparseGradOptimizer):
r"""Node embedding optimizer using the Adam algorithm.
@@ -653,6 +723,9 @@ def __init__(
)
self._dtype = dtype

# setup tensors for optimizer states
self.setup(self._params)

def _setup_uva(self, name, mem, power):
self._is_using_uva[name] = True
mem_nd = pin_memory_inplace(mem)
@@ -863,3 +936,24 @@ def update(self, idx, grad, emb):
# can use it
std_event.wait()
emb.weight[state_idx] -= std_values_dst

@property
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Unlike its torch counterpart, the returned 'param_groups' does not contain
the parameters, because gathering the whole embedding is very expensive.
It contains the other attributes, e.g., lr, betas, and eps, for debugging.
"""
return [
{
"lr": self._lr,
"betas": (self._beta1, self._beta2),
"eps": self._eps,
}
]

def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
self._lr = groups[0]["lr"]
self._beta1, self._beta2 = groups[0]["betas"]
self._eps = groups[0]["eps"]
35 changes: 35 additions & 0 deletions tests/python/pytorch/dataloading/test_dataloader.py
@@ -622,6 +622,40 @@ def test_edge_dataloader_excludes(
break


def test_edge_dataloader_exclusion_without_all_reverses():
data_dict = {
("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
("B", "BA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
("B", "BC", "C"): (torch.tensor([0]), torch.tensor([0])),
("C", "CA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
}
g = dgl.heterograph(data_dict=data_dict)
block_sampler = dgl.dataloading.MultiLayerNeighborSampler(
fanouts=[1], replace=True
)
block_sampler = dgl.dataloading.as_edge_prediction_sampler(
block_sampler,
exclude="reverse_types",
reverse_etypes={"AB": "BA"},
)
d = dgl.dataloading.DataLoader(
graph=g,
indices={
"AB": torch.tensor([0]),
"BC": torch.tensor([0]),
},
graph_sampler=block_sampler,
batch_size=2,
shuffle=True,
drop_last=False,
num_workers=0,
device=F.ctx(),
use_ddp=False,
)

next(iter(d))


def dummy_worker_init_fn(worker_id):
pass

@@ -647,3 +681,4 @@ def test_dataloader_worker_init_fn():
test_edge_dataloader_excludes(
"reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5])
)
test_edge_dataloader_exclusion_without_all_reverses()