Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII committed Feb 14, 2023
2 parents acfdfe0 + fb8677b commit 66165d4
Show file tree
Hide file tree
Showing 27 changed files with 260 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

dgl_linux_libs = 'build/libdgl.so, build/runUnitTests, python/dgl/_ffi/_cy3/core.cpython-*-x86_64-linux-gnu.so, build/tensoradapter/pytorch/*.so, build/dgl_sparse/*.so'
// Currently DGL on Windows is not working with Cython yet
dgl_win64_libs = "build\\dgl.dll, build\\runUnitTests.exe, build\\tensoradapter\\pytorch\\*.dll"
dgl_win64_libs = "build\\dgl.dll, build\\runUnitTests.exe, build\\tensoradapter\\pytorch\\*.dll, build\\dgl_sparse\\*.dll"

def init_git() {
sh 'rm -rf *'
Expand Down
11 changes: 1 addition & 10 deletions dgl_sparse/src/spspmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,7 @@ torch::Tensor _CSRMask(
auto val = TorchTensorToDGLArray(value);
auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
runtime::NDArray ret;
if (val->dtype.bits == 32) {
ret = aten::CSRGetData<float>(csr, row, col, val, 0.);
} else if (val->dtype.bits == 64) {
ret = aten::CSRGetData<double>(csr, row, col, val, 0.);
} else {
TORCH_CHECK(
false, "Dtype of value for SpSpMM should be 32 or 64 bits but got: " +
std::to_string(val->dtype.bits));
}
runtime::NDArray ret = aten::CSRGetFloatingData(csr, row, col, val, 0.);
return DGLArrayToTorchTensor(ret);
}

Expand Down
30 changes: 30 additions & 0 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,36 @@ runtime::NDArray CSRGetData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler);

/**
* @brief Get the data for each (row, col) pair, then index into the weights
* array.
*
* The operator supports matrix with duplicate entries but only one matched
* entry will be returned for each (row, col) pair. Support duplicate input
* (row, col) pairs.
*
* If some (row, col) pairs do not contain a valid non-zero elements to index
* into the weights array, DGL returns the value \a filler for that pair
* instead.
*
* @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* @note This is the floating point number version of `CSRGetData`, which
removes the dtype template.
*
* @param mat Sparse matrix.
* @param rows Row index.
* @param cols Column index.
* @param weights The weights array.
* @param filler The value to return for row-column pairs not existent in the
* matrix.
* @return Data array. The i^th element is the data of (rows[i], cols[i])
*/
runtime::NDArray CSRGetFloatingData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, double filler);

/** @brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr);

Expand Down
3 changes: 1 addition & 2 deletions python/dgl/nn/pytorch/conv/egatconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class EGATConv(nn.Module):
f_{ij}^{\prime} &= \mathrm{LeakyReLU}\left(A [ h_{i} \| f_{ij} \| h_{j}]\right)
where :math:`f_{ij}^{\prime}` are edge features, :math:`\mathrm{A}` is weight matrix and
:math: `\vec{F}` is weight vector. After that, resulting node features
:math:`\vec{F}` is weight vector. After that, resulting node features
:math:`h_{i}^{\prime}` are updated in the same way as in regular GAT.
Parameters
Expand Down
17 changes: 13 additions & 4 deletions python/dgl/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,25 @@
def load_dgl_sparse():
"""Load DGL C++ sparse library"""
version = torch.__version__.split("+", maxsplit=1)[0]
basename = f"libdgl_sparse_pytorch_{version}.so"

if sys.platform.startswith("linux"):
basename = f"libdgl_sparse_pytorch_{version}.so"
elif sys.platform.startswith("darwin"):
basename = f"libdgl_sparse_pytorch_{version}.dylib"
elif sys.platform.startswith("win"):
basename = f"dgl_sparse_pytorch_{version}.dll"
else:
raise NotImplementedError("Unsupported system: %s" % sys.platform)

dirname = os.path.dirname(libinfo.find_lib_path()[0])
path = os.path.join(dirname, "dgl_sparse", basename)
if not os.path.exists(path):
raise FileNotFoundError(f"Cannot find DGL C++ sparse library at {path}")

try:
torch.classes.load_library(path)
except Exception: # pylint: disable=W0703
raise ImportError("Cannot load DGL C++ sparse library")


# TODO(zhenkun): support other platforms
if sys.platform.startswith("linux"):
load_dgl_sparse()
load_dgl_sparse()
2 changes: 1 addition & 1 deletion python/dgl/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3508,7 +3508,7 @@ def radius_graph(x, r, p=2, self_loop=False,
distances = th.cdist(x, x, p=p, compute_mode=compute_mode)

if not self_loop:
distances.fill_diagonal_(r + 1e-4)
distances.fill_diagonal_(r + 1)

edges = th.nonzero(distances <= r, as_tuple=True)

Expand Down
112 changes: 112 additions & 0 deletions script/create_dev_conda_env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#!/bin/bash

CUDA_VERSIONS="10.2,11.3,11.6,11.7"
TORCH_VERSION="1.12.0"

usage() {
cat << EOF
usage: bash $0 OPTIONS
examples:
bash $0 -c
bash $0 -g 11.7
Created a developement environment for DGL developers.
OPTIONS:
-h Show this message.
-c Create dev environment in CPU mode.
-g Create dev environment in GPU mode with specified CUDA version,
supported: $CUDA_VERSIONS.
EOF
}

validate() {
values=$(echo "$1" | tr "," "\n")
for value in $values
do
if [[ "$value" == $2 ]]; then
return 0
fi
done
return 1
}

confirm() {
echo "Continue? [yes/no]:"
read confirm
if [[ ! $confirm == "yes" ]]; then
exit 0
fi
}

# Parsing flags.
while getopts "cg:h" flag; do
if [[ $flag == "c" ]]; then
cpu=1
elif [[ $flag == "g" ]]; then
gpu=$OPTARG
elif [[ $flag == "h" ]]; then
usage
exit 0
else
usage
exit 1
fi
done

if [[ -n $gpu && $cpu -eq 1 ]]; then
echo "Only one mode can be specified."
exit 1
fi

if [[ -z $gpu && -z $cpu ]]; then
usage
exit 1
fi

# Set up CPU mode.
if [[ $cpu -eq 1 ]]; then
torchversion=${TORCH_VERSION}"+cpu"
name="dgl-dev-cpu"
fi

# Set up GPU mode.
if [[ -n $gpu ]]; then
if ! validate ${CUDA_VERSIONS} ${gpu}; then
echo "Error: Invalid CUDA version."
usage
exit 1
fi

echo "Confirm the installed CUDA version matches the specified one."
confirm

torchversion=${TORCH_VERSION}"+cu"${gpu//[-._]/}
name="dgl-dev-gpu"
fi

echo "Confirm you are excuting the script from your DGL root directory."
echo "Current working directory: $PWD"
confirm

# Prepare the conda environment yaml file.
rand=$(echo "$RANDOM" | md5sum | head -c 20)
mkdir -p /tmp/$rand
cp script/dgl_dev.yml.template /tmp/$rand/dgl_dev.yml
sed -i "s|__NAME__|$name|g" /tmp/$rand/dgl_dev.yml
sed -i "s|__TORCH_VERSION__|$torchversion|g" /tmp/$rand/dgl_dev.yml
sed -i "s|__DGL_HOME__|$PWD|g" /tmp/$rand/dgl_dev.yml

# Ask for final confirmation.
echo "--------------------------------------------------"
cat /tmp/$rand/dgl_dev.yml
echo "--------------------------------------------------"
echo "Create a conda enviroment with the config?"
confirm

# Create conda environment.
conda env create -f /tmp/$rand/dgl_dev.yml

# Clean up created tmp conda environment yaml file.
rm -rf /tmp/$rand
exit 0
29 changes: 29 additions & 0 deletions script/dgl_dev.yml.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: __NAME__
dependencies:
- python=3.7.0
- pip
- pip:
- --find-links https://download.pytorch.org/whl/torch_stable.html
- cython
- filelock
- matplotlib
- networkx
- nltk
- nose
- numpy
- ogb
- pandas
- psutil
- pyarrow
- pydantic
- pytest
- pyyaml
- rdflib
- requests[security]
- scikit-learn
- scipy
- torch==__TORCH_VERSION__
- torchmetrics
- tqdm
variables:
DGL_HOME: __DGL_HOME__
12 changes: 12 additions & 0 deletions src/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,18 @@ NDArray CSRGetData(
return ret;
}

runtime::NDArray CSRGetFloatingData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, double filler) {
if (weights->dtype.bits == 64) {
return CSRGetData<double>(csr, rows, cols, weights, filler);
} else {
CHECK(weights->dtype.bits == 32)
<< "CSRGetFloatingData only supports 32 or 64 bits floaring number";
return CSRGetData<float>(csr, rows, cols, weights, filler);
}
}

template NDArray CSRGetData<float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<double>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@
import subprocess
import sys

import pytest

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)

EXAMPLE_ROOT = os.path.join(
os.path.dirname(os.path.relpath(__file__)),
"..",
"..",
"..",
"examples",
"sparse",
)
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from dgl.sparse import diag, DiagMatrix, identity

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_elementwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
import pytest
import torch

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_elementwise_op_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

from dgl.sparse import from_coo, power

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


def all_close_sparse(A, row, col, val, shape):
rowA, colA = A.coo()
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
sparse_matrix_to_torch_sparse,
)

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
import pytest
import torch

# TODO(#5013): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)

dgl_op_map = {
"sum": "sum",
"amin": "smin",
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_sddmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from .utils import clone_detach_and_grad, rand_coo, rand_csc, rand_csr

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

from dgl.sparse import from_coo, softmax

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_D", [None, 2])
@pytest.mark.parametrize("csr", [True, False])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from dgl.sparse import from_coo, from_csc, from_csr, val_like

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("row", [(0, 0, 1, 2), (0, 1, 2, 4)])
Expand Down
4 changes: 0 additions & 4 deletions tests/pytorch/sparse/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

from dgl.sparse import diag, from_coo

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch/sparse/test_unary_op_diag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys

import backend as F
import pytest
import torch

from dgl.sparse import diag
Expand Down
Loading

0 comments on commit 66165d4

Please sign in to comment.