From 6dbde611c0ddc81db76e47b7b1061ec597fa40f4 Mon Sep 17 00:00:00 2001 From: Hoppe Date: Mon, 10 Jun 2024 13:45:48 +0200 Subject: [PATCH 1/7] added tutorials on HPC for HAICON --- .pre-commit-config.yaml | 2 +- tutorials/hpc/2_basics.ipynb | 1 - tutorials/hpc/3_internals.ipynb | 1 - tutorials/hpc/4_loading_preprocessing.ipynb | 1 - tutorials/hpc/5_matrix_factorizations.ipynb | 1 - tutorials/hpc/6_clustering.ipynb | 1 - tutorials/hpc/basics/basics_broadcast.py | 19 ++ tutorials/hpc/basics/basics_datatypes.py | 27 +++ tutorials/hpc/basics/basics_distributed.py | 179 +++++++++++++++++ tutorials/hpc/basics/basics_dndarrays.py | 36 ++++ tutorials/hpc/basics/basics_gpu.py | 80 ++++++++ tutorials/hpc/basics/basics_operations.py | 41 ++++ tutorials/hpc/clustering/clustering.py | 182 ++++++++++++++++++ tutorials/hpc/clustering/iris.csv | 150 +++++++++++++++ tutorials/hpc/internals/internals.py | 130 +++++++++++++ tutorials/hpc/internals/internals_1.py | 44 +++++ tutorials/hpc/internals/internals_2.py | 89 +++++++++ .../loading_preprocessing_script.py | 98 ++++++++++ .../matrix_factorizations.py | 108 +++++++++++ .../{hpc => hpc_notebooks}/1_intro.ipynb | 0 tutorials/hpc_notebooks/2_basics.ipynb | 1 + tutorials/hpc_notebooks/3_internals.ipynb | 1 + .../4_loading_preprocessing.ipynb | 1 + .../5_matrix_factorizations.ipynb | 1 + tutorials/hpc_notebooks/6_clustering.ipynb | 1 + 25 files changed, 1189 insertions(+), 6 deletions(-) delete mode 120000 tutorials/hpc/2_basics.ipynb delete mode 120000 tutorials/hpc/3_internals.ipynb delete mode 120000 tutorials/hpc/4_loading_preprocessing.ipynb delete mode 120000 tutorials/hpc/5_matrix_factorizations.ipynb delete mode 120000 tutorials/hpc/6_clustering.ipynb create mode 100644 tutorials/hpc/basics/basics_broadcast.py create mode 100644 tutorials/hpc/basics/basics_datatypes.py create mode 100644 tutorials/hpc/basics/basics_distributed.py create mode 100644 tutorials/hpc/basics/basics_dndarrays.py create mode 100644 tutorials/hpc/basics/basics_gpu.py create mode 100644 tutorials/hpc/basics/basics_operations.py create mode 100644 tutorials/hpc/clustering/clustering.py create mode 100644 tutorials/hpc/clustering/iris.csv create mode 100644 tutorials/hpc/internals/internals.py create mode 100644 tutorials/hpc/internals/internals_1.py create mode 100644 tutorials/hpc/internals/internals_2.py create mode 100644 tutorials/hpc/loading_preprocessing/loading_preprocessing_script.py create mode 100644 tutorials/hpc/matrix_factorizations/matrix_factorizations.py rename tutorials/{hpc => hpc_notebooks}/1_intro.ipynb (100%) create mode 100644 tutorials/hpc_notebooks/2_basics.ipynb create mode 100644 tutorials/hpc_notebooks/3_internals.ipynb create mode 100644 tutorials/hpc_notebooks/4_loading_preprocessing.ipynb create mode 100644 tutorials/hpc_notebooks/5_matrix_factorizations.ipynb create mode 100644 tutorials/hpc_notebooks/6_clustering.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b520af675..9206271ba2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: rev: 6.3.0 # pick a git hash / tag to point to hooks: - id: pydocstyle - exclude: "tests|benchmarks|examples|scripts|setup.py" #|heat/utils/data/mnist.py|heat/utils/data/_utils.py ? + exclude: "tutorials|tests|benchmarks|examples|scripts|setup.py" #|heat/utils/data/mnist.py|heat/utils/data/_utils.py ? - repo: "https://github.com/citation-file-format/cffconvert" rev: "054bda51dbe278b3e86f27c890e3f3ac877d616c" hooks: diff --git a/tutorials/hpc/2_basics.ipynb b/tutorials/hpc/2_basics.ipynb deleted file mode 120000 index 68f73c480c..0000000000 --- a/tutorials/hpc/2_basics.ipynb +++ /dev/null @@ -1 +0,0 @@ -../local/2_basics.ipynb \ No newline at end of file diff --git a/tutorials/hpc/3_internals.ipynb b/tutorials/hpc/3_internals.ipynb deleted file mode 120000 index 4105ea65c6..0000000000 --- a/tutorials/hpc/3_internals.ipynb +++ /dev/null @@ -1 +0,0 @@ -../local/3_internals.ipynb \ No newline at end of file diff --git a/tutorials/hpc/4_loading_preprocessing.ipynb b/tutorials/hpc/4_loading_preprocessing.ipynb deleted file mode 120000 index c2010bb811..0000000000 --- a/tutorials/hpc/4_loading_preprocessing.ipynb +++ /dev/null @@ -1 +0,0 @@ -../local/4_loading_preprocessing.ipynb \ No newline at end of file diff --git a/tutorials/hpc/5_matrix_factorizations.ipynb b/tutorials/hpc/5_matrix_factorizations.ipynb deleted file mode 120000 index 41ae51349c..0000000000 --- a/tutorials/hpc/5_matrix_factorizations.ipynb +++ /dev/null @@ -1 +0,0 @@ -../local/5_matrix_factorizations.ipynb \ No newline at end of file diff --git a/tutorials/hpc/6_clustering.ipynb b/tutorials/hpc/6_clustering.ipynb deleted file mode 120000 index 8668389f7e..0000000000 --- a/tutorials/hpc/6_clustering.ipynb +++ /dev/null @@ -1 +0,0 @@ -../local/6_clustering.ipynb \ No newline at end of file diff --git a/tutorials/hpc/basics/basics_broadcast.py b/tutorials/hpc/basics/basics_broadcast.py new file mode 100644 index 0000000000..0eabb4fed5 --- /dev/null +++ b/tutorials/hpc/basics/basics_broadcast.py @@ -0,0 +1,19 @@ +import heat as ht + +# --- +# Heat implements the same broadcasting rules (implicit repetion of an operation when the rank/shape of the operands do not match) as NumPy does, e.g.: + +a = ht.arange(10) + 3 +print(f"broadcast example of adding single value 3 to [0, 1, ..., 9]: {a}") + + +a = ht.ones( + ( + 3, + 4, + ) +) +b = ht.arange(4) +print( + f"broadcasing across the first dimension of {a} with shape = (3, 4) and {b} with shape = (4): {a+b}" +) diff --git a/tutorials/hpc/basics/basics_datatypes.py b/tutorials/hpc/basics/basics_datatypes.py new file mode 100644 index 0000000000..1fe7e02548 --- /dev/null +++ b/tutorials/hpc/basics/basics_datatypes.py @@ -0,0 +1,27 @@ +import heat as ht +import numpy as np +import torch + +# ### Data Types +# +# Heat supports various data types and operations to retrieve and manipulate the type of a Heat array. However, in contrast to NumPy, Heat is limited to logical (bool) and numerical types (uint8, int16/32/64, float32/64, and complex64/128). +# +# **NOTE:** by default, Heat will allocate floating-point values in single precision, due to a much higher processing performance on GPUs. This is one of the main differences between Heat and NumPy. + +a = ht.zeros( + ( + 3, + 4, + ) +) +print(f"floating-point values in single precision is default: {a.dtype}") + +b = torch.zeros(3, 4) +print(f"like in PyTorch: {b.dtype}") + + +b = np.zeros((3, 4)) +print(f"whereas floating-point values in double precision is default in numpy: {b.dtype}") + +b = a.astype(ht.int64) +print(f"casting to int64: {b}") diff --git a/tutorials/hpc/basics/basics_distributed.py b/tutorials/hpc/basics/basics_distributed.py new file mode 100644 index 0000000000..18be27c7fa --- /dev/null +++ b/tutorials/hpc/basics/basics_distributed.py @@ -0,0 +1,179 @@ +import heat as ht + +# ### Distributed Computing +# +# Heat is also able to make use of distributed processing capabilities such as those in high-performance cluster systems. For this, Heat exploits the fact that the operations performed on a multi-dimensional array are usually identical for all data items. Hence, a data-parallel processing strategy can be chosen, where the total number of data items is equally divided among all processing nodes. An operation is then performed individually on the local data chunks and, if necessary, communicates partial results behind the scenes. A Heat array assumes the role of a virtual overlay of the local chunks and realizes and coordinates the computations - see the figure below for a visual representation of this concept. +# +# +# +# The chunks are always split along a singular dimension (i.e. 1-D domain decomposition) of the array. You can specify this in Heat by using the `split` paramter. This parameter is present in all relevant functions, such as array creation (`zeros(), ones(), ...`) or I/O (`load()`) functions. +# +# +# +# +# Examples are provided below. The result of an operation on a Heat tensor will in most cases preserve the split of the respective operands. However, in some cases the split axis might change. For example, a transpose of a Heat array will equally transpose the split axis. Furthermore, a reduction operations, e.g. `sum()` that is performed across the split axis, might remove data partitions entirely. The respective function behaviors can be found in Heat's documentation. +# +# You may also modify the data partitioning of a Heat array by using the `resplit()` function. This allows you to repartition the data as you so choose. Please note, that this should be used sparingly and for small data amounts only, as it entails significant data copying across the network. Finally, a Heat array without any split, i.e. `split=None` (default), will result in redundant copies of data on each computation node. +# +# On a technical level, Heat follows the so-called [Bulk Synchronous Parallel (BSP)](https://en.wikipedia.org/wiki/Bulk_synchronous_parallel) processing model. For the network communication, Heat utilizes the [Message Passing Interface (MPI)](https://computing.llnl.gov/tutorials/mpi/), a *de facto* standard on modern high-performance computing systems. It is also possible to use MPI on your laptop or desktop computer. Respective software packages are available for all major operating systems. In order to run a Heat script, you need to start it slightly differently than you are probably used to. This +# +# ```bash +# python ./my_script.py +# ``` +# +# becomes this instead: +# +# ```bash +# mpirun -n python ./my_script.py +# ``` +# On an HPC cluster you'll of course use SBATCH or similar. +# +# +# Let's see some examples of working with distributed Heat: + +# In the following examples, we'll recreate the array shown in the figure, a 3-dimensional DNDarray of integers ranging from 0 to 59 (5 matrices of size (4,3)). + + +dndarray = ht.arange(60).reshape(5, 4, 3) +if dndarray.comm.rank == 0: + print(f"3-dimensional DNDarray of integers ranging from 0 to 59: {dndarray}") + + +# Notice the additional metadata printed with the DNDarray. With respect to a numpy ndarray, the DNDarray has additional information on the device (in this case, the CPU) and the `split` axis. In the example above, the split axis is `None`, meaning that the DNDarray is not distributed and each MPI process has a full copy of the data. +# +# Let's experiment with a distributed DNDarray: we'll split the same DNDarray as above, but distributed along the major axis. + + +dndarray = ht.arange(60, split=0).reshape(5, 4, 3) +if dndarray.comm.rank == 0: + print(f"3-dimensional DNDarray splitted across dim 0: {dndarray}") + + +# The `split` axis is now 0, meaning that the DNDarray is distributed along the first axis. Each MPI process has a slice of the data along the first axis. In order to see the data on each process, we can print the "local array" via the `larray` attribute. + + +if dndarray.comm.rank == 0: + print(f"data on each process: {dndarray.larray}") + + +# Note that the `larray` is a `torch.Tensor` object. This is the underlying tensor that holds the data. The `dndarray` object is an MPI-aware wrapper around these process-local tensors, providing memory-distributed functionality and information. + +# The DNDarray can be distributed along any axis. Modify the `split` attribute when creating the DNDarray in the cell above, to distribute it along a different axis, and see how the `larray`s change. You'll notice that the distributed arrays are always load-balanced, meaning that the data are distributed as evenly as possible across the MPI processes. + +# The `DNDarray` object has a number of methods and attributes that are useful for distributed computing. In particular, it keeps track of its global and local (on a given process) shape through distributed operations and array manipulations. The DNDarray is also associated to a `comm` object, the MPI communicator. +# +# (In MPI, the *communicator* is a group of processes that can communicate with each other. The `comm` object is a `MPI.COMM_WORLD` communicator, which is the default communicator that includes all the processes. The `comm` object is used to perform collective operations, such as reductions, scatter, gather, and broadcast. The `comm` object is also used to perform point-to-point communication between processes.) + + +print(f"Global shape on rank {dndarray.comm.rank}: {dndarray.shape}") +print(f"Local shape on rank: {dndarray.comm.rank}: {dndarray.lshape}") + + +# You can perform a vast number of operations on DNDarrays distributed over multi-node and/or multi-GPU resources. Check out our [Numpy coverage tables](https://github.com/helmholtz-analytics/heat/blob/main/coverage_tables.md) to see what operations are already supported. +# +# The result of an operation on DNDarays will in most cases preserve the `split` or distribution axis of the respective operands. However, in some cases the split axis might change. For example, a transpose of a Heat array will equally transpose the split axis. Furthermore, a reduction operations, e.g. `sum()` that is performed across the split axis, might remove data partitions entirely. The respective function behaviors can be found in Heat's documentation. + + +# transpose +print(dndarray.T) + + +# reduction operation along the distribution axis +print(dndarray.sum(axis=0)) + + +other_dndarray = ht.arange(60, 120, split=0).reshape(5, 4, 3) # distributed reshape + +# element-wise multiplication +print(dndarray * other_dndarray) + + +# As we saw earlier, because the underlying data objects are PyTorch tensors, we can easily create DNDarrays on GPUs or move DNDarrays to GPUs. This allows us to perform distributed array operations on multi-GPU systems. +# +# So far we have demostrated small, easy-to-parallelize arithmetical operations. Let's move to linear algebra. Heat's `linalg` module supports a wide range of linear algebra operations, including matrix multiplication. Matrix multiplication is a very common operation data analysis, it is computationally intensive, and not trivial to parallelize. +# +# With Heat, you can perform matrix multiplication on distributed DNDarrays, and the operation will be parallelized across the MPI processes. Here on 4 GPUs: + + +import torch + +if torch.cuda.is_available(): + device = "gpu" +else: + device = "cpu" + +n, m = 400, 400 +x = ht.random.randn(n, m, split=0, device=device) # distributed RNG +y = ht.random.randn(m, n, split=None, device=device) +z = x @ y +print(z) + +# `ht.linalg.matmul` or `@` breaks down the matrix multiplication into a series of smaller `torch` matrix multiplications, which are then distributed across the MPI processes. This operation can be very communication-intensive on huge matrices that both require distribution, and users should choose the `split` axis carefully to minimize communication overhead. + +# You can experiment with sizes and the `split` parameter (distribution axis) for both matrices and time the result. Note that: +# - If you set **`split=None` for both matrices**, each process (in this case, each GPU) will attempt to multiply the entire matrices. Depending on the matrix sizes, the GPU memory might be insufficient. (And if you can multiply the matrices on a single GPU, it's much more efficient to stick to PyTorch's `torch.linalg.matmul` function.) +# - If **`split` is not None for both matrices**, each process will only hold a slice of the data, and will need to communicate data with other processes in order to perform the multiplication. This **introduces huge communication overhead**, but allows you to perform the multiplication on larger matrices than would fit in the memory of a single GPU. +# - If **`split` is None for one matrix and not None for the other**, the multiplication does not require communication, and the result will be distributed. If your data size allows it, you should always favor this option. +# +# Time the multiplication for different split parameters and see how the performance changes. +# +# + + +import time + +start = time.time() +z = x @ y +end = time.time() +print("runtime: ", end - start) + + +# Heat supports many linear algebra operations: +# ```bash +# >>> ht.linalg. +# ht.linalg.basics ht.linalg.hsvd_rtol( ht.linalg.projection( ht.linalg.triu( +# ht.linalg.cg( ht.linalg.inv( ht.linalg.qr( ht.linalg.vdot( +# ht.linalg.cross( ht.linalg.lanczos( ht.linalg.solver ht.linalg.vecdot( +# ht.linalg.det( ht.linalg.matmul( ht.linalg.svdtools ht.linalg.vector_norm( +# ht.linalg.dot( ht.linalg.matrix_norm( ht.linalg.trace( +# ht.linalg.hsvd( ht.linalg.norm( ht.linalg.transpose( +# ht.linalg.hsvd_rank( ht.linalg.outer( ht.linalg.tril( +# ``` +# +# and a lot more is in the works, including distributed eigendecompositions, SVD, and more. If the operation you need is not yet supported, leave us a note [here](tinyurl.com/demoissues) and we'll get back to you. + +# You can of course perform all operations on CPUs. You can leave out the `device` attribute entirely. + +# ### Interoperability +# +# We can easily create DNDarrays from PyTorch tensors and numpy ndarrays. We can also convert DNDarrays to PyTorch tensors and numpy ndarrays. This makes it easy to integrate Heat into existing PyTorch and numpy workflows. Here a basic example with xarrays: + + +import xarray as xr + +local_xr = xr.DataArray(dndarray.larray, dims=("z", "y", "x")) +# proceed with local xarray operations +print(local_xr) + + +# **NOTE:** this is not a distributed `xarray`, but local xarray objects on each rank. +# Work on [expanding xarray support](https://github.com/helmholtz-analytics/heat/pull/1183) is ongoing. +# + +# Heat will try to reuse the memory of the original array as much as possible. If you would prefer a copy with different memory, the ```copy``` keyword argument can be used when creating a DNDArray from other libraries. + + +import torch + +torch_array = torch.arange(5) +heat_array = ht.array(torch_array, copy=False) +heat_array[0] = -1 +print(torch_array) + +torch_array = torch.arange(5) +heat_array = ht.array(torch_array, copy=True) +heat_array[0] = -1 +print(torch_array) + + +# Interoperability is a key feature of Heat, and we are constantly working to increase Heat's compliance to the [Python array API standard](https://data-apis.org/array-api/latest/). As usual, please [let us know](tinyurl.com/demoissues) if you encounter any issues or have any feature requests. diff --git a/tutorials/hpc/basics/basics_dndarrays.py b/tutorials/hpc/basics/basics_dndarrays.py new file mode 100644 index 0000000000..94deb6c461 --- /dev/null +++ b/tutorials/hpc/basics/basics_dndarrays.py @@ -0,0 +1,36 @@ +import heat as ht + +# ### DNDarrays +# +# +# Similar to a NumPy `ndarray`, a Heat `dndarray` (we'll get to the `d` later) is a grid of values of a single (one particular) type. The number of dimensions is the number of axes of the array, while the shape of an array is a tuple of integers giving the number of elements of the array along each dimension. +# +# Heat emulates NumPy's API as closely as possible, allowing for the use of well-known **array creation functions**. + + +a = ht.array([1, 2, 3]) +print("array creation with values [1,2,3] with the heat array method:") +print(a) + +a = ht.ones( + ( + 4, + 5, + ) +) +print("array creation of shape = (4, 5) example with the heat ones method:") +print(a) + +a = ht.arange(10) +print("array creation with [0,1,...,9] example with the heat arange method:") +print(a) + +a = ht.full( + ( + 3, + 2, + ), + fill_value=9, +) +print("array creation with ones and shape = (3, 2) with the heat full method:") +print(a) diff --git a/tutorials/hpc/basics/basics_gpu.py b/tutorials/hpc/basics/basics_gpu.py new file mode 100644 index 0000000000..785972379f --- /dev/null +++ b/tutorials/hpc/basics/basics_gpu.py @@ -0,0 +1,80 @@ +import heat as ht + +# ## Parallel Processing +# --- +# +# Heat's actual power lies in the possibility to exploit the processing performance of modern accelerator hardware (GPUs) as well as distributed (high-performance) cluster systems. All operations executed on CPUs are, to a large extent, vectorized (AVX) and thread-parallelized (OpenMP). Heat builds on PyTorch, so it supports GPU acceleration on Nvidia and AMD GPUs. +# +# For distributed computations, your system or laptop needs to have Message Passing Interface (MPI) installed. For GPU computations, your system needs to have one or more suitable GPUs and (MPI-aware) CUDA/ROCm ecosystem. +# +# **NOTE:** The GPU examples below will only properly execute on a computer with a GPU. Make sure to either start the notebook on an appropriate machine or copy and paste the examples into a script and execute it on a suitable device. + +# ### GPUs +# +# Heat's array creation functions all support an additional parameter that which places the data on a specific device. By default, the CPU is selected, but it is also possible to directly allocate the data on a GPU. + + +import torch + +if torch.cuda.is_available(): + ht.zeros( + ( + 3, + 4, + ), + device="gpu", + ) + +# Arrays on the same device can be seamlessly used in any Heat operation. + +if torch.cuda.is_available(): + a = ht.zeros( + ( + 3, + 4, + ), + device="gpu", + ) + b = ht.ones( + ( + 3, + 4, + ), + device="gpu", + ) + print(a + b) + + +# However, performing operations on arrays with mismatching devices will purposefully result in an error (due to potentially large copy overhead). + +if torch.cuda.is_available(): + a = ht.full( + ( + 3, + 4, + ), + 4, + device="cpu", + ) + b = ht.ones( + ( + 3, + 4, + ), + device="gpu", + ) + print(a + b) + +# It is possible to explicitly move an array from one device to the other and back to avoid this error. + +if torch.cuda.is_available(): + a = ht.full( + ( + 3, + 4, + ), + 4, + device="gpu", + ) + a.cpu() + print(a + b) diff --git a/tutorials/hpc/basics/basics_operations.py b/tutorials/hpc/basics/basics_operations.py new file mode 100644 index 0000000000..ccd7705dd1 --- /dev/null +++ b/tutorials/hpc/basics/basics_operations.py @@ -0,0 +1,41 @@ +import heat as ht + +# ### Operations +# +# Heat supports many mathematical operations, ranging from simple element-wise functions, binary arithmetic operations, and linear algebra, to more powerful reductions. Operations are by default performed on the entire array or they can be performed along one or more of its dimensions when available. Most relevant for data-intensive applications is that **all Heat functionalities support memory-distributed computation and GPU acceleration**. This holds for all operations, including reductions, statistics, linear algebra, and high-level algorithms. +# +# You can try out the few simple examples below if you want, but we will skip to the [Parallel Processing](#Parallel-Processing) section to see memory-distributed operations in action. + +a = ht.full( + ( + 3, + 4, + ), + 8, +) +b = ht.ones( + ( + 3, + 4, + ) +) +c = a + b +print("matrix addition a + b:") +print(c) + + +c = ht.sub(a, b) +print("matrix substraction a - b:") +print(c) + +c = ht.arange(5).sin() +print("application of sin() elementwise:") +print(c) + +c = a.T +print("transpose operation:") +print(c) + +c = b.sum(axis=1) +print("summation of array elements:") +print(c) diff --git a/tutorials/hpc/clustering/clustering.py b/tutorials/hpc/clustering/clustering.py new file mode 100644 index 0000000000..24df3da224 --- /dev/null +++ b/tutorials/hpc/clustering/clustering.py @@ -0,0 +1,182 @@ +# Cluster Analysis +# ================ +# +# This tutorial is an interactive version of our static [clustering tutorial on ReadTheDocs](https://heat.readthedocs.io/en/stable/tutorial_clustering.html). +# +# We will demonstrate memory-distributed analysis with k-means and k-medians from the ``heat.cluster`` module. As usual, we will run the analysis on a small dataset for demonstration. We need to have an `ipcluster` running to distribute the computation. +# +# We will use matplotlib for visualization of data and results. + + +import heat as ht + + +# Spherical Clouds of Datapoints +# ------------------------------ +# For a simple demonstration of the clustering process and the differences between the algorithms, we will create an +# artificial dataset, consisting of two circularly shaped clusters positioned at $(x_1=2, y_1=2)$ and $(x_2=-2, y_2=-2)$ in 2D space. +# For each cluster we will sample 100 arbitrary points from a circle with radius of $R = 1.0$ by drawing random numbers +# for the spherical coordinates $( r\in [0,R], \phi \in [0,2\pi])$, translating these to cartesian coordinates +# and shifting them by $+2$ for cluster ``c1`` and $-2$ for cluster ``c2``. The resulting concatenated dataset ``data`` has shape +# $(200, 2)$ and is distributed among the ``p`` processes along axis 0 (sample axis). + + +num_ele = 100 +R = 1.0 + +# Create default spherical point cloud +# Sample radius between 0 and 1, and phi between 0 and 2pi +r = ht.random.rand(num_ele, split=0) * R +phi = ht.random.rand(num_ele, split=0) * 2 * ht.constants.PI + +# Transform spherical coordinates to cartesian coordinates +x = r * ht.cos(phi) +y = r * ht.sin(phi) + + +# Stack the sampled points and shift them to locations (2,2) and (-2, -2) +cluster1 = ht.stack((x + 2, y + 2), axis=1) +cluster2 = ht.stack((x - 2, y - 2), axis=1) + +data = ht.concatenate((cluster1, cluster2), axis=0) + + +# Let's plot the data for illustration. In order to do so with matplotlib, we need to unsplit the data (gather it from +# all processes) and transform it into a numpy array. Plotting can only be done on rank 0. + + +data_np = ht.resplit(data, axis=None).numpy() + + +# import matplotlib.pyplot as plt +# plt.plot(data_np[:,0], data_np[:,1], 'bo') + + +# Now we perform the clustering analysis with kmeans. We chose 'kmeans++' as an intelligent way of sampling the +# initial centroids. + +kmeans = ht.cluster.KMeans(n_clusters=2, init="kmeans++") +labels = kmeans.fit_predict(data).squeeze() +centroids = kmeans.cluster_centers_ + +# Select points assigned to clusters c1 and c2 +c1 = data[ht.where(labels == 0), :] +c2 = data[ht.where(labels == 1), :] +# After slicing, the arrays are no longer distributed evenly among the processes; we might need to balance the load +c1.balance_() # in-place operation +c2.balance_() + +print( + f"Number of points assigned to c1: {c1.shape[0]} \n" + f"Number of points assigned to c2: {c2.shape[0]} \n" + f"Centroids = {centroids}" +) + + +# Let's plot the assigned clusters and the respective centroids: + +# just for plotting: collect all the data on each process and extract the numpy arrays. This will copy data to CPU if necessary. +c1_np = c1.numpy() +c2_np = c2.numpy() + +""" +import matplotlib.pyplot as plt +# plotting on 1 process only +plt.plot(c1_np[:,0], c1_np[:,1], 'x', color='#f0781e') +plt.plot(c2_np[:,0], c2_np[:,1], 'x', color='#5a696e') +plt.plot(centroids[0,0],centroids[0,1], '^', markersize=10, markeredgecolor='black', color='#f0781e' ) +plt.plot(centroids[1,0],centroids[1,1], '^', markersize=10, markeredgecolor='black',color='#5a696e') +plt.savefig('centroids_1.png') +""" + + +# We can also cluster the data with kmedians. The respective advanced initial centroid sampling is called 'kmedians++'. + +kmedians = ht.cluster.KMedians(n_clusters=2, init="kmedians++") +labels = kmedians.fit_predict(data).squeeze() +centroids = kmedians.cluster_centers_ + +# Select points assigned to clusters c1 and c2 +c1 = data[ht.where(labels == 0), :] +c2 = data[ht.where(labels == 1), :] +# After slicing, the arrays are not distributed equally among the processes anymore; we need to balance +c1.balance_() +c2.balance_() + +print( + f"Number of points assigned to c1: {c1.shape[0]} \n" + f"Number of points assigned to c2: {c2.shape[0]} \n" + f"Centroids = {centroids}" +) + + +# Plotting the assigned clusters and the respective centroids: + +c1_np = c1.numpy() +c2_np = c2.numpy() + +""" +plt.plot(c1_np[:,0], c1_np[:,1], 'x', color='#f0781e') +plt.plot(c2_np[:,0], c2_np[:,1], 'x', color='#5a696e') +plt.plot(centroids[0,0],centroids[0,1], '^', markersize=10, markeredgecolor='black', color='#f0781e' ) +plt.plot(centroids[1,0],centroids[1,1], '^', markersize=10, markeredgecolor='black',color='#5a696e') +plt.savefig('centroids_2.png') +""" + + +# The Iris Dataset +# ------------------------------ +# The _iris_ dataset is a well known example for clustering analysis. It contains 4 measured features for samples from +# three different types of iris flowers. A subset of 150 samples is included in formats h5, csv and netcdf in the [Heat repository under 'heat/heat/datasets'](https://github.com/helmholtz-analytics/heat/tree/main/heat/datasets), and can be loaded in a distributed manner with Heat's parallel dataloader. +# +# **NOTE: you might have to change the path to the dataset in the following cell.** + +iris = ht.load("iris.csv", sep=";", split=0) + + +# Feel free to try out the other [loading options](https://heat.readthedocs.io/en/stable/autoapi/heat/core/io/index.html#heat.core.io.load) as well. +# +# Fitting the dataset with `kmeans`: + +k = 3 +kmeans = ht.cluster.KMeans(n_clusters=k, init="kmeans++") +kmeans.fit(iris) + +# Let's see what the results are. In theory, there are 50 samples of each of the 3 iris types: setosa, versicolor and virginica. We will plot the results in a 3D scatter plot, coloring the samples according to the assigned cluster. + +labels = kmeans.predict(iris).squeeze() + +# Select points assigned to clusters c1, c2 and c3 +c1 = iris[ht.where(labels == 0), :] +c2 = iris[ht.where(labels == 1), :] +c3 = iris[ht.where(labels == 2), :] +# After slicing, the arrays are not distributed equally among the processes anymore; we need to balance +# TODO is balancing really necessary? +c1.balance_() +c2.balance_() +c3.balance_() + +print( + f"Number of points assigned to c1: {c1.shape[0]} \n" + f"Number of points assigned to c2: {c2.shape[0]} \n" + f"Number of points assigned to c3: {c3.shape[0]}" +) + + +# compare Heat results with sklearn +from sklearn.cluster import KMeans +import sklearn.datasets + +k = 3 +iris_sk = sklearn.datasets.load_iris().data +kmeans_sk = KMeans(n_clusters=k, init="k-means++").fit(iris_sk) +labels_sk = kmeans_sk.predict(iris_sk) + +c1_sk = iris_sk[labels_sk == 0, :] +c2_sk = iris_sk[labels_sk == 1, :] +c3_sk = iris_sk[labels_sk == 2, :] +print( + f"Number of points assigned to c1: {c1_sk.shape[0]} \n" + f"Number of points assigned to c2: {c2_sk.shape[0]} \n" + f"Number of points assigned to c3: {c3_sk.shape[0]}" +) diff --git a/tutorials/hpc/clustering/iris.csv b/tutorials/hpc/clustering/iris.csv new file mode 100644 index 0000000000..8bc57da193 --- /dev/null +++ b/tutorials/hpc/clustering/iris.csv @@ -0,0 +1,150 @@ +5.1;3.5;1.4;0.2 +4.9;3.0;1.4;0.2 +4.7;3.2;1.3;0.2 +4.6;3.1;1.5;0.2 +5.0;3.6;1.4;0.2 +5.4;3.9;1.7;0.4 +4.6;3.4;1.4;0.3 +5.0;3.4;1.5;0.2 +4.4;2.9;1.4;0.2 +4.9;3.1;1.5;0.1 +5.4;3.7;1.5;0.2 +4.8;3.4;1.6;0.2 +4.8;3.0;1.4;0.1 +4.3;3.0;1.1;0.1 +5.8;4.0;1.2;0.2 +5.7;4.4;1.5;0.4 +5.4;3.9;1.3;0.4 +5.1;3.5;1.4;0.3 +5.7;3.8;1.7;0.3 +5.1;3.8;1.5;0.3 +5.4;3.4;1.7;0.2 +5.1;3.7;1.5;0.4 +4.6;3.6;1.0;0.2 +5.1;3.3;1.7;0.5 +4.8;3.4;1.9;0.2 +5.0;3.0;1.6;0.2 +5.0;3.4;1.6;0.4 +5.2;3.5;1.5;0.2 +5.2;3.4;1.4;0.2 +4.7;3.2;1.6;0.2 +4.8;3.1;1.6;0.2 +5.4;3.4;1.5;0.4 +5.2;4.1;1.5;0.1 +5.5;4.2;1.4;0.2 +4.9;3.1;1.5;0.1 +5.0;3.2;1.2;0.2 +5.5;3.5;1.3;0.2 +4.9;3.1;1.5;0.1 +4.4;3.0;1.3;0.2 +5.1;3.4;1.5;0.2 +5.0;3.5;1.3;0.3 +4.5;2.3;1.3;0.3 +4.4;3.2;1.3;0.2 +5.0;3.5;1.6;0.6 +5.1;3.8;1.9;0.4 +4.8;3.0;1.4;0.3 +5.1;3.8;1.6;0.2 +4.6;3.2;1.4;0.2 +5.3;3.7;1.5;0.2 +5.0;3.3;1.4;0.2 +7.0;3.2;4.7;1.4 +6.4;3.2;4.5;1.5 +6.9;3.1;4.9;1.5 +5.5;2.3;4.0;1.3 +6.5;2.8;4.6;1.5 +5.7;2.8;4.5;1.3 +6.3;3.3;4.7;1.6 +4.9;2.4;3.3;1.0 +6.6;2.9;4.6;1.3 +5.2;2.7;3.9;1.4 +5.0;2.0;3.5;1.0 +5.9;3.0;4.2;1.5 +6.0;2.2;4.0;1.0 +6.1;2.9;4.7;1.4 +5.6;2.9;3.6;1.3 +6.7;3.1;4.4;1.4 +5.6;3.0;4.5;1.5 +5.8;2.7;4.1;1.0 +6.2;2.2;4.5;1.5 +5.6;2.5;3.9;1.1 +5.9;3.2;4.8;1.8 +6.1;2.8;4.0;1.3 +6.3;2.5;4.9;1.5 +6.1;2.8;4.7;1.2 +6.4;2.9;4.3;1.3 +6.6;3.0;4.4;1.4 +6.8;2.8;4.8;1.4 +6.7;3.0;5.0;1.7 +6.0;2.9;4.5;1.5 +5.7;2.6;3.5;1.0 +5.5;2.4;3.8;1.1 +5.5;2.4;3.7;1.0 +5.8;2.7;3.9;1.2 +6.0;2.7;5.1;1.6 +5.4;3.0;4.5;1.5 +6.0;3.4;4.5;1.6 +6.7;3.1;4.7;1.5 +6.3;2.3;4.4;1.3 +5.6;3.0;4.1;1.3 +5.5;2.5;4.0;1.3 +5.5;2.6;4.4;1.2 +6.1;3.0;4.6;1.4 +5.8;2.6;4.0;1.2 +5.0;2.3;3.3;1.0 +5.6;2.7;4.2;1.3 +5.7;3.0;4.2;1.2 +5.7;2.9;4.2;1.3 +6.2;2.9;4.3;1.3 +5.1;2.5;3.0;1.1 +5.7;2.8;4.1;1.3 +6.3;3.3;6.0;2.5 +5.8;2.7;5.1;1.9 +7.1;3.0;5.9;2.1 +6.3;2.9;5.6;1.8 +6.5;3.0;5.8;2.2 +7.6;3.0;6.6;2.1 +4.9;2.5;4.5;1.7 +7.3;2.9;6.3;1.8 +6.7;2.5;5.8;1.8 +7.2;3.6;6.1;2.5 +6.5;3.2;5.1;2.0 +6.4;2.7;5.3;1.9 +6.8;3.0;5.5;2.1 +5.7;2.5;5.0;2.0 +5.8;2.8;5.1;2.4 +6.4;3.2;5.3;2.3 +6.5;3.0;5.5;1.8 +7.7;3.8;6.7;2.2 +7.7;2.6;6.9;2.3 +6.0;2.2;5.0;1.5 +6.9;3.2;5.7;2.3 +5.6;2.8;4.9;2.0 +7.7;2.8;6.7;2.0 +6.3;2.7;4.9;1.8 +6.7;3.3;5.7;2.1 +7.2;3.2;6.0;1.8 +6.2;2.8;4.8;1.8 +6.1;3.0;4.9;1.8 +6.4;2.8;5.6;2.1 +7.2;3.0;5.8;1.6 +7.4;2.8;6.1;1.9 +7.9;3.8;6.4;2.0 +6.4;2.8;5.6;2.2 +6.3;2.8;5.1;1.5 +6.1;2.6;5.6;1.4 +7.7;3.0;6.1;2.3 +6.3;3.4;5.6;2.4 +6.4;3.1;5.5;1.8 +6.0;3.0;4.8;1.8 +6.9;3.1;5.4;2.1 +6.7;3.1;5.6;2.4 +6.9;3.1;5.1;2.3 +5.8;2.7;5.1;1.9 +6.8;3.2;5.9;2.3 +6.7;3.3;5.7;2.5 +6.7;3.0;5.2;2.3 +6.3;2.5;5.0;1.9 +6.5;3.0;5.2;2.0 +6.2;3.4;5.4;2.3 +5.9;3.0;5.1;1.8 diff --git a/tutorials/hpc/internals/internals.py b/tutorials/hpc/internals/internals.py new file mode 100644 index 0000000000..08f76080e0 --- /dev/null +++ b/tutorials/hpc/internals/internals.py @@ -0,0 +1,130 @@ +import heat as ht +import torch + +# # Heat as infrastructure for MPI applications +# +# In this section, we'll go through some Heat-specific functionalities that simplify the implementation of a data-parallel application in Python. We'll demonstrate them on small arrays and 4 processes on a single cluster node, but the functionalities are indeed meant for a multi-node set up with huge arrays that cannot be processed on a single node. + + +# We already mentioned that the DNDarray object is "MPI-aware". Each DNDarray is associated to an MPI communicator, it is aware of the number of processes in the communicator, and it knows the rank of the process that owns it. +# + +a = ht.random.randn(7, 4, 3, split=0) +if a.comm.rank == 0: + print(f"a.com gets the communicator {a.comm} associated with DNDarray a") + +# MPI size = total number of processes +size = a.comm.size + +if a.comm.rank == 0: + print(f"a is distributed over {size} processes") + print(f"a is a distributed {a.ndim}-dimensional array with global shape {a.shape}") + + +# MPI rank = rank of each process +rank = a.comm.rank +# Local shape = shape of the data on each process +local_shape = a.lshape +print(f"Rank {rank} holds a slice of a with local shape {local_shape}") + + +# ### Distribution map +# +# In many occasions, when building a memory-distributed pipeline it will be convenient for each rank to have information on what ranks holds which slice of the distributed array. +# +# The `lshape_map` attribute of a DNDarray gathers (or, if possible, calculates) this info from all processes and stores it as metadata of the DNDarray. Because it is meant for internal use, it is stored in a torch tensor, not a DNDarray. +# +# The `lshape_map` tensor is a 2D tensor, where the first dimension is the number of processes and the second dimension is the number of dimensions of the array. Each row of the tensor contains the local shape of the array on a process. + + +lshape_map = a.lshape_map +if a.comm.rank == 0: + print(f"lshape_map available on any process: {lshape_map}") + +# Go back to where we created the DNDarray and and create `a` with a different split axis. See how the `lshape_map` changes. + +# ### Modifying the DNDarray distribution +# +# In a distributed pipeline, it is sometimes necessary to change the distribution of a DNDarray, when the array is not distributed in the most convenient way for the next operation / algorithm. +# +# Depending on your needs, you can choose between: +# - `DNDarray.redistribute_()`: This method keeps the original split axis, but redistributes the data of the DNDarray according to a "target map". +# - `DNDarray.resplit_()`: This method changes the split axis of the DNDarray. This is a more expensive operation, and should be used only when absolutely necessary. Depending on your needs and available resources, in some cases it might be wiser to keep a copy of the DNDarray with a different split axis. +# +# Let's see some examples. + + +# redistribute +target_map = a.lshape_map +target_map[:, a.split] = torch.tensor([1, 2, 2, 2]) +# in-place redistribution (see ht.redistribute for out-of-place) +a.redistribute_(target_map=target_map) + +# new lshape map after redistribution +a.lshape_map + +# local arrays after redistribution +a.larray + + +# resplit +a.resplit_(axis=1) + +a.lshape_map + + +# You can use the `resplit_` method (in-place), or `ht.resplit` (out-of-place) to change the distribution axis, but also to set the distribution axis to None. The latter corresponds to an MPI.Allgather operation that gathers the entire array on each process. This is useful when you've achieved a small enough data size that can be processed on a single device, and you want to avoid communication overhead. + + +# "un-split" distributed array +a.resplit_(axis=None) +# each process now holds a copy of the entire array + + +# The opposite is not true, i.e. you cannot use `resplit_` to distribute an array with split=None. In that case, you must use the `ht.array()` factory function: + + +# make `a` split again +a = ht.array(a, split=0) + + +# ### Making disjoint data into a global DNDarray +# +# Another common occurrence in a data-parallel pipeline: you have addressed the embarassingly-parallel part of your algorithm with any array framework, each process working independently from the others. You now want to perform a non-embarassingly-parallel operation on the entire dataset, with Heat as a backend. +# +# You can use the `ht.array` factory function with the `is_split` argument to create a DNDarray from a disjoint (on each MPI process) set of arrays. The `is_split` argument indicates the axis along which the disjoint data is to be "joined" into a global, distributed DNDarray. + + +# create some random local arrays on each process +import numpy as np + +local_array = np.random.rand(3, 4) + +# join them into a distributed array +a_0 = ht.array(local_array, is_split=0) +a_0.shape + + +# Change the cell above and join the arrays along a different axis. Note that the shapes of the local arrays must be consistent along the non-split axes. They can differ along the split axis. + +# The `ht.array` function takes any data object as an input that can be converted to a torch tensor. + +# Once you've made your disjoint data into a DNDarray, you can apply any Heat operation or algorithm to it and exploit the cumulative RAM of all the processes in the communicator. + +# You can access the MPI communication functionalities of the DNDarray through the `comm` attribute, i.e.: +# +# ```python +# # these are just examples, this cell won't do anything +# a.comm.Allreduce(a, b, op=MPI.SUM) +# +# a.comm.Allgather(a, b) +# a.comm.Isend(a, dest=1, tag=0) +# ``` +# +# etc. + +# In the next notebooks, we'll show you how we use Heat's distributed-array infrastructure to scale complex data analysis workflows to large datasets and high-performance computing resources. +# +# - [Data loading and preprocessing](4_loading_preprocessing.ipynb) +# - [Matrix factorization algorithms](5_matrix_factorizations.ipynb) +# - [Clustering algorithms](6_clustering.ipynb) diff --git a/tutorials/hpc/internals/internals_1.py b/tutorials/hpc/internals/internals_1.py new file mode 100644 index 0000000000..dfeed9ba74 --- /dev/null +++ b/tutorials/hpc/internals/internals_1.py @@ -0,0 +1,44 @@ +import heat as ht +import torch + +# # Heat as infrastructure for MPI applications +# +# In this section, we'll go through some Heat-specific functionalities that simplify the implementation of a data-parallel application in Python. We'll demonstrate them on small arrays and 4 processes on a single cluster node, but the functionalities are indeed meant for a multi-node set up with huge arrays that cannot be processed on a single node. + + +# We already mentioned that the DNDarray object is "MPI-aware". Each DNDarray is associated to an MPI communicator, it is aware of the number of processes in the communicator, and it knows the rank of the process that owns it. +# + +a = ht.random.randn(7, 4, 3, split=1) +if a.comm.rank == 0: + print(f"a.com gets the communicator {a.comm} associated with DNDarray a") + +# MPI size = total number of processes +size = a.comm.size + +if a.comm.rank == 0: + print(f"a is distributed over {size} processes") + print(f"a is a distributed {a.ndim}-dimensional array with global shape {a.shape}") + + +# MPI rank = rank of each process +rank = a.comm.rank +# Local shape = shape of the data on each process +local_shape = a.lshape +print(f"Rank {rank} holds a slice of a with local shape {local_shape}") + + +# ### Distribution map +# +# In many occasions, when building a memory-distributed pipeline it will be convenient for each rank to have information on what ranks holds which slice of the distributed array. +# +# The `lshape_map` attribute of a DNDarray gathers (or, if possible, calculates) this info from all processes and stores it as metadata of the DNDarray. Because it is meant for internal use, it is stored in a torch tensor, not a DNDarray. +# +# The `lshape_map` tensor is a 2D tensor, where the first dimension is the number of processes and the second dimension is the number of dimensions of the array. Each row of the tensor contains the local shape of the array on a process. + + +lshape_map = a.lshape_map +if a.comm.rank == 0: + print(f"lshape_map available on any process: {lshape_map}") + +# Go back to where we created the DNDarray and and create `a` with a different split axis. See how the `lshape_map` changes. diff --git a/tutorials/hpc/internals/internals_2.py b/tutorials/hpc/internals/internals_2.py new file mode 100644 index 0000000000..99aaf647fe --- /dev/null +++ b/tutorials/hpc/internals/internals_2.py @@ -0,0 +1,89 @@ +import heat as ht +import torch + +# ### Modifying the DNDarray distribution +# +# In a distributed pipeline, it is sometimes necessary to change the distribution of a DNDarray, when the array is not distributed in the most convenient way for the next operation / algorithm. +# +# Depending on your needs, you can choose between: +# - `DNDarray.redistribute_()`: This method keeps the original split axis, but redistributes the data of the DNDarray according to a "target map". +# - `DNDarray.resplit_()`: This method changes the split axis of the DNDarray. This is a more expensive operation, and should be used only when absolutely necessary. Depending on your needs and available resources, in some cases it might be wiser to keep a copy of the DNDarray with a different split axis. +# +# Let's see some examples. + +a = ht.random.randn(7, 4, 3, split=1) + +# redistribute +target_map = a.lshape_map +target_map[:, a.split] = torch.tensor([1, 2, 2, 2]) +# in-place redistribution (see ht.redistribute for out-of-place) +a.redistribute_(target_map=target_map) + +# new lshape map after redistribution +a.lshape_map + +# local arrays after redistribution +a.larray + + +# resplit +a.resplit_(axis=1) + +a.lshape_map + + +# You can use the `resplit_` method (in-place), or `ht.resplit` (out-of-place) to change the distribution axis, but also to set the distribution axis to None. The latter corresponds to an MPI.Allgather operation that gathers the entire array on each process. This is useful when you've achieved a small enough data size that can be processed on a single device, and you want to avoid communication overhead. + + +# "un-split" distributed array +a.resplit_(axis=None) +# each process now holds a copy of the entire array + + +# The opposite is not true, i.e. you cannot use `resplit_` to distribute an array with split=None. In that case, you must use the `ht.array()` factory function: + + +# make `a` split again +a = ht.array(a, split=0) + + +# ### Making disjoint data into a global DNDarray +# +# Another common occurrence in a data-parallel pipeline: you have addressed the embarassingly-parallel part of your algorithm with any array framework, each process working independently from the others. You now want to perform a non-embarassingly-parallel operation on the entire dataset, with Heat as a backend. +# +# You can use the `ht.array` factory function with the `is_split` argument to create a DNDarray from a disjoint (on each MPI process) set of arrays. The `is_split` argument indicates the axis along which the disjoint data is to be "joined" into a global, distributed DNDarray. + + +# create some random local arrays on each process +import numpy as np + +local_array = np.random.rand(3, 4) + +# join them into a distributed array +a_0 = ht.array(local_array, is_split=0) +a_0.shape + + +# Change the cell above and join the arrays along a different axis. Note that the shapes of the local arrays must be consistent along the non-split axes. They can differ along the split axis. + +# The `ht.array` function takes any data object as an input that can be converted to a torch tensor. + +# Once you've made your disjoint data into a DNDarray, you can apply any Heat operation or algorithm to it and exploit the cumulative RAM of all the processes in the communicator. + +# You can access the MPI communication functionalities of the DNDarray through the `comm` attribute, i.e.: +# +# ```python +# # these are just examples, this cell won't do anything +# a.comm.Allreduce(a, b, op=MPI.SUM) +# +# a.comm.Allgather(a, b) +# a.comm.Isend(a, dest=1, tag=0) +# ``` +# +# etc. + +# In the next notebooks, we'll show you how we use Heat's distributed-array infrastructure to scale complex data analysis workflows to large datasets and high-performance computing resources. +# +# - [Data loading and preprocessing](4_loading_preprocessing.ipynb) +# - [Matrix factorization algorithms](5_matrix_factorizations.ipynb) +# - [Clustering algorithms](6_clustering.ipynb) diff --git a/tutorials/hpc/loading_preprocessing/loading_preprocessing_script.py b/tutorials/hpc/loading_preprocessing/loading_preprocessing_script.py new file mode 100644 index 0000000000..3d37992c21 --- /dev/null +++ b/tutorials/hpc/loading_preprocessing/loading_preprocessing_script.py @@ -0,0 +1,98 @@ +# # Loading and Preprocessing +# +# ### Refresher +# +# Using PyTorch as compute engine and mpi4py for communication, Heat implements a number of array operations and algorithms that are optimized for memory-distributed data volumes. This allows you to tackle datasets that are too large for single-node (or worse, single-GPU) processing. +# +# As opposed to task-parallel frameworks, Heat takes a data-parallel approach, meaning that each "worker" or MPI process performs the same tasks on different slices of the data. Many operations and algorithms are not embarassingly parallel, and involve data exchange between processes. Heat operations and algorithms are designed to minimize this communication overhead, and to make it transparent to the user. +# +# In other words: +# - you don't have to worry about optimizing data chunk sizes; +# - you don't have to make sure your research problem is embarassingly parallel, or artificially make your dataset smaller so your RAM is sufficient; +# - you do have to make sure that you have sufficient **overall** RAM to run your global task (e.g. number of nodes / GPUs). + +# The following shows some I/O and preprocessing examples. We'll use small datasets here as each of us only has access to one node only. + +# ### I/O +# +# Let's start with loading a data set. Heat supports reading and writing from/into shared memory for a number of formats, including HDF5, NetCDF, and because we love scientists, csv. Check out the `ht.load` and `ht.save` functions for more details. Here we will load data in [HDF5 format](https://en.wikipedia.org/wiki/Hierarchical_Data_Format). +# +# This particular example data set (generated from all Asteroids from the [JPL Small Body Database](https://ssd.jpl.nasa.gov/sb/)) is really small, but it allows to demonstrate the basic functionality of Heat. +# + +# The above cell should return [0, 1, 2, 3]. +# +# Now let's import `heat` and load the data set. + +import heat as ht + +# X = ht.load_hdf5("../data/sbdb_asteroids.h5",dtype=ht.float64,dataset="data",split=0) + +# Some random data for small scale tests +X = ht.random.randn(1000, 3, split=0) + +# We have loaded the entire data onto 4 MPI processes, each with 12 cores. We have created `X` with `split=0`, so each process stores evenly-sized slices of the data along dimension 0. + +# ### Data exploration +# +# Let's get an idea of the size of the data. + + +# print global metadata once only +if X.comm.rank == 0: + print(f"X is a {X.ndim}-dimensional array with shape{X.shape}") + print(f"X takes up {X.nbytes/1e6} MB of memory.") + +# X is a matrix of shape *(datapoints, features)*. +# +# To get a first overview, we can print the data and determine its feature-wise mean, variance, min, max etc. These are reduction operations along the datapoints dimension, which is also the `split` dimension. You don't have to implement [`MPI.Allreduce`](https://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) operations yourself, communication is handled by Heat operations. + + +features_mean = ht.mean(X, axis=0) +features_var = ht.var(X, axis=0) +features_max = ht.max(X, axis=0) +features_min = ht.min(X, axis=0) +# ht.percentile is buggy, see #1389, we'll leave it out for now +# features_median = ht.percentile(X,50.,axis=0) + + +if ht.MPI_WORLD.rank == 0: + print(f"Mean: {features_mean}") + print(f"Var: {features_var}") + print(f"Max: {features_max}") + print(f"Min: {features_min}") + + +# Note that the `features_...` DNDarrays are no longer distributed, i.e. a copy of these results exists on each GPU, as the split dimension of the input data has been lost in the reduction operations. + +# ### Preprocessing/scaling +# +# Next, we can preprocess the data, e.g., by standardizing and/or normalizing. Heat offers several preprocessing routines for doing so, the API is similar to [`sklearn.preprocessing`](https://scikit-learn.org/stable/modules/preprocessing.html) so adapting existing code shouldn't be too complicated. +# +# Again, please let us know if you're missing any features. + + +# Standard Scaler +scaler = ht.preprocessing.StandardScaler() +X_standardized = scaler.fit_transform(X) +standardized_mean = ht.mean(X_standardized, axis=0) +standardized_var = ht.var(X_standardized, axis=0) + +if ht.MPI_WORLD.rank == 0: + print(f"Standard Scaler Mean: {standardized_mean}") + print(f"Standard Scaler Var: {standardized_var}") + +# Robust Scaler +scaler = ht.preprocessing.RobustScaler() +X_robust = scaler.fit_transform(X) +robust_mean = ht.mean(X_robust, axis=0) +robust_var = ht.var(X_robust, axis=0) + +if ht.MPI_WORLD.rank == 0: + print(f"Robust Scaler Mean: {robust_mean}") + print(f"Robust Scaler Median: {robust_var}") + + +# Within Heat, you have several options to apply memory-distributed machine learning algorithms on your data. +# +# Is the algorithm you're looking for not yet implemented? [Let us know](https://github.com/helmholtz-analytics/heat/issues/new/choose)! diff --git a/tutorials/hpc/matrix_factorizations/matrix_factorizations.py b/tutorials/hpc/matrix_factorizations/matrix_factorizations.py new file mode 100644 index 0000000000..1a77b2436e --- /dev/null +++ b/tutorials/hpc/matrix_factorizations/matrix_factorizations.py @@ -0,0 +1,108 @@ +# # Matrix factorizations +# +# ### Refresher +# +# Using PyTorch as compute engine and mpi4py for communication, Heat implements a number of array operations and algorithms that are optimized for memory-distributed data volumes. This allows you to tackle datasets that are too large for single-node (or worse, single-GPU) processing. +# +# As opposed to task-parallel frameworks, Heat takes a data-parallel approach, meaning that each "worker" or MPI process performs the same tasks on different slices of the data. Many operations and algorithms are not embarassingly parallel, and involve data exchange between processes. Heat operations and algorithms are designed to minimize this communication overhead, and to make it transparent to the user. +# +# In other words: +# - you don't have to worry about optimizing data chunk sizes; +# - you don't have to make sure your research problem is embarassingly parallel, or artificially make your dataset smaller so your RAM is sufficient; +# - you do have to make sure that you have sufficient **overall** RAM to run your global task (e.g. number of nodes / GPUs). + +# In the following, we will demonstrate the usage of Heat's truncated SVD algorithm. + +# ### SVD and its truncated counterparts in a nutshell +# +# Let $X \in \mathbb{R}^{m \times n}$ be a matrix, e.g., given by a data set consisting of $m$ data points $\in \mathbb{R}^n$ stacked together. The so-called **singular value decomposition (SVD)** of $X$ is given by +# +# $$ +# X = U \Sigma V^T +# $$ +# +# where $U \in \mathbb{R}^{m \times r_X}$ and $V \in \mathbb{R}^{n \times r_X}$ have orthonormal columns, $\Sigma = \text{diag}(\sigma_1,...,\sigma_{r_X}) \in \mathbb{R}^{r_X \times r_X}$ is a diagonal matrix containing the so-called singular values $\sigma_1 \geq \sigma_2 \geq ... \geq \sigma_{r_X} > 0$, and $r_X \leq \min(m,n)$ denotes the rank of $X$ (i.e. the dimension of the subspace of $\mathbb{R}^m$ spanned by the columns of $X$). Since $\Sigma = U^T X V$ is diagonal, one can imagine this decomposition as finding orthogonal coordinate transformations under which $X$ looks "linear". + +# ### SVD in data science +# +# In data science, SVD is more often known as **principle component analysis (PCA)**, the columns of $U$ being called the principle components of $X$. In fact, in many applications **truncated SVD/PCA** suffices: to reduce $X$ to the "essential" information, one chooses a truncation rank $0 < r \leq r_X$ and considers the truncated SVD/PCA given by +# +# $$ +# X \approx X_r := U_{[:,:r]} \Sigma_{[:r,:r]} V_{[:,:r]}^T +# $$ +# +# where we have used `numpy`-like notation for selecting only the first $r$ columns of $U$ and $V$, respectively. The rationale behind this is that if the first $r$ singular values of $X$ are much larger than the remaining ones, $X_r$ will still contain all "essential" information contained in $X$; in mathematical terms: +# +# $$ +# \lVert X_r - X \rVert_{F}^2 = \sum_{i=r+1}^{r_X} \sigma_i^2, +# $$ +# +# where $\lVert \cdot \rVert_F$ denotes the Frobenius norm. Thus, truncated SVD/PCA may be used for, e.g., +# * filtering away non-essential information in order to get a "feeling" for the main characteristics of your data set, +# * to detect linear (or "almost" linear) dependencies in your data, +# * to generate features for further processing of your data. +# +# Moreover, there is a plenty of more advanced data analytics and data-based simulation techniques, such as, e.g., Proper Orthogonal Decomposition (POD) or Dynamic Mode Decomposition (DMD), that are based on SVD/PCA. + +# ### Truncated SVD in Heat +# +# In Heat we have currently implemented an algorithm for computing an approximate truncated SVD, where truncation takes place either w.r.t. a fixed truncation-rank (`heat.linalg.hsvd_rank`) or w.r.t. a desired accuracy (`heat.linalg.hsvd_rtol`). In the latter case it can be ensured that it holds for the "reconstruction error": +# +# $$ +# \frac{\lVert X - U U^T X \rVert_F}{\lVert X \rVert_F} \overset{!}{\leq} \text{rtol}, +# $$ +# +# where $U$ denotes the approximate left-singular vectors of $X$ computed by `heat.linalg.hsvd_rtol`. +# + +# To demonstrate the usage of Heat's truncated SVD algorithm, we will load the data set from the last example and then compute its truncated SVD. As usual, first we need to gain access to the MPI environment. + + +import heat as ht + +X = ht.load_hdf5( + "/p/scratch/training2404/data/JPL_SBDB/sbdb_asteroids.h5", dataset="data", split=0 +).T + + +# Note that due to the transpose, `X` is distributed along the columns now; this is required by the hSVD-algorithm. + +# Let's first compute the truncated SVD by setting the relative tolerance. + + +# compute truncated SVD w.r.t. relative tolerance +svd_with_reltol = ht.linalg.hsvd_rtol(X, rtol=1.0e-2, compute_sv=True, silent=False) +print("relative residual:", svd_with_reltol[3], "rank: ", svd_with_reltol[0].shape[1]) + + +# Alternatively, you can compute a truncated SVD with a fixed truncation rank: + +# compute truncated SVD w.r.t. a fixed truncation rank +svd_with_rank = ht.linalg.hsvd_rank(X, maxrank=3, compute_sv=True, silent=False) +print("relative residual:", svd_with_rank[3], "rank: ", svd_with_rank[0].shape[1]) + +# Once we have computed the truncated SVD, we can use it to approximate the original data matrix `X` by the truncated matrix `X_r`. +# +# Check out the plot below to see how Heat's truncated SVD algorithm scales with the number of MPI processes and size of the dataset. + +#
+# +# +# +#
+# + +# ### Other factorizations +# +# Other common factorization algorithms are supported in Heat, such as: +# - QR decomposition (`heat.linalg.qr`), +# - Lanczos algorithm for computing the largest eigenvalues and corresponding eigenvectors (`heat.linalg.lanczos`) +# +# Check out our [`linalg` PRs](https://github.com/helmholtz-analytics/heat/pulls?q=is%3Aopen+is%3Apr+label%3Alinalg) to see what's in progress. +# + +# **References for hierarchical SVD** +# +# 1. Iwen, Ong. *A distributed and incremental SVD algorithm for agglomerative data analysis on large networks.* SIAM J. Matrix Anal. Appl., **37** (4), 2016. +# 2. Himpe, Leibner, Rave. *Hierarchical approximate proper orthogonal decomposition.* SIAM J. Sci. Comput., **4** (5), 2018. +# 3. Halko, Martinsson, Tropp. *Finding Structure with Randomness: Probabilistic Algorithms for Constructing Approximate Matrix Decompositions.* SIAM Rev. 53, **2** (2011) diff --git a/tutorials/hpc/1_intro.ipynb b/tutorials/hpc_notebooks/1_intro.ipynb similarity index 100% rename from tutorials/hpc/1_intro.ipynb rename to tutorials/hpc_notebooks/1_intro.ipynb diff --git a/tutorials/hpc_notebooks/2_basics.ipynb b/tutorials/hpc_notebooks/2_basics.ipynb new file mode 100644 index 0000000000..d82dbc7f70 --- /dev/null +++ b/tutorials/hpc_notebooks/2_basics.ipynb @@ -0,0 +1 @@ +../local/2_basics.ipynb diff --git a/tutorials/hpc_notebooks/3_internals.ipynb b/tutorials/hpc_notebooks/3_internals.ipynb new file mode 100644 index 0000000000..eb2b3a38d8 --- /dev/null +++ b/tutorials/hpc_notebooks/3_internals.ipynb @@ -0,0 +1 @@ +../local/3_internals.ipynb diff --git a/tutorials/hpc_notebooks/4_loading_preprocessing.ipynb b/tutorials/hpc_notebooks/4_loading_preprocessing.ipynb new file mode 100644 index 0000000000..622dcfbb48 --- /dev/null +++ b/tutorials/hpc_notebooks/4_loading_preprocessing.ipynb @@ -0,0 +1 @@ +../local/4_loading_preprocessing.ipynb diff --git a/tutorials/hpc_notebooks/5_matrix_factorizations.ipynb b/tutorials/hpc_notebooks/5_matrix_factorizations.ipynb new file mode 100644 index 0000000000..a0291ba9f6 --- /dev/null +++ b/tutorials/hpc_notebooks/5_matrix_factorizations.ipynb @@ -0,0 +1 @@ +../local/5_matrix_factorizations.ipynb diff --git a/tutorials/hpc_notebooks/6_clustering.ipynb b/tutorials/hpc_notebooks/6_clustering.ipynb new file mode 100644 index 0000000000..9c8a780bc0 --- /dev/null +++ b/tutorials/hpc_notebooks/6_clustering.ipynb @@ -0,0 +1 @@ +../local/6_clustering.ipynb From 1466a99a4e1dbefd5ff3ffa8dcbb3bcba0717fee Mon Sep 17 00:00:00 2001 From: Hoppe Date: Mon, 10 Jun 2024 13:55:15 +0200 Subject: [PATCH 2/7] renamed folders --- tutorials/hpc/{basics => 01_basics}/basics_broadcast.py | 0 tutorials/hpc/{basics => 01_basics}/basics_datatypes.py | 0 tutorials/hpc/{basics => 01_basics}/basics_distributed.py | 0 tutorials/hpc/{basics => 01_basics}/basics_dndarrays.py | 0 tutorials/hpc/{basics => 01_basics}/basics_gpu.py | 0 tutorials/hpc/{basics => 01_basics}/basics_operations.py | 0 tutorials/hpc/{internals => 02_internals}/internals.py | 0 tutorials/hpc/{internals => 02_internals}/internals_1.py | 0 tutorials/hpc/{internals => 02_internals}/internals_2.py | 0 .../loading_preprocessing_script.py | 0 .../matrix_factorizations.py | 0 tutorials/hpc/{clustering => 05_clustering}/clustering.py | 0 tutorials/hpc/{clustering => 05_clustering}/iris.csv | 0 13 files changed, 0 insertions(+), 0 deletions(-) rename tutorials/hpc/{basics => 01_basics}/basics_broadcast.py (100%) rename tutorials/hpc/{basics => 01_basics}/basics_datatypes.py (100%) rename tutorials/hpc/{basics => 01_basics}/basics_distributed.py (100%) rename tutorials/hpc/{basics => 01_basics}/basics_dndarrays.py (100%) rename tutorials/hpc/{basics => 01_basics}/basics_gpu.py (100%) rename tutorials/hpc/{basics => 01_basics}/basics_operations.py (100%) rename tutorials/hpc/{internals => 02_internals}/internals.py (100%) rename tutorials/hpc/{internals => 02_internals}/internals_1.py (100%) rename tutorials/hpc/{internals => 02_internals}/internals_2.py (100%) rename tutorials/hpc/{loading_preprocessing => 03_loading_preprocessing}/loading_preprocessing_script.py (100%) rename tutorials/hpc/{matrix_factorizations => 04_matrix_factorizations}/matrix_factorizations.py (100%) rename tutorials/hpc/{clustering => 05_clustering}/clustering.py (100%) rename tutorials/hpc/{clustering => 05_clustering}/iris.csv (100%) diff --git a/tutorials/hpc/basics/basics_broadcast.py b/tutorials/hpc/01_basics/basics_broadcast.py similarity index 100% rename from tutorials/hpc/basics/basics_broadcast.py rename to tutorials/hpc/01_basics/basics_broadcast.py diff --git a/tutorials/hpc/basics/basics_datatypes.py b/tutorials/hpc/01_basics/basics_datatypes.py similarity index 100% rename from tutorials/hpc/basics/basics_datatypes.py rename to tutorials/hpc/01_basics/basics_datatypes.py diff --git a/tutorials/hpc/basics/basics_distributed.py b/tutorials/hpc/01_basics/basics_distributed.py similarity index 100% rename from tutorials/hpc/basics/basics_distributed.py rename to tutorials/hpc/01_basics/basics_distributed.py diff --git a/tutorials/hpc/basics/basics_dndarrays.py b/tutorials/hpc/01_basics/basics_dndarrays.py similarity index 100% rename from tutorials/hpc/basics/basics_dndarrays.py rename to tutorials/hpc/01_basics/basics_dndarrays.py diff --git a/tutorials/hpc/basics/basics_gpu.py b/tutorials/hpc/01_basics/basics_gpu.py similarity index 100% rename from tutorials/hpc/basics/basics_gpu.py rename to tutorials/hpc/01_basics/basics_gpu.py diff --git a/tutorials/hpc/basics/basics_operations.py b/tutorials/hpc/01_basics/basics_operations.py similarity index 100% rename from tutorials/hpc/basics/basics_operations.py rename to tutorials/hpc/01_basics/basics_operations.py diff --git a/tutorials/hpc/internals/internals.py b/tutorials/hpc/02_internals/internals.py similarity index 100% rename from tutorials/hpc/internals/internals.py rename to tutorials/hpc/02_internals/internals.py diff --git a/tutorials/hpc/internals/internals_1.py b/tutorials/hpc/02_internals/internals_1.py similarity index 100% rename from tutorials/hpc/internals/internals_1.py rename to tutorials/hpc/02_internals/internals_1.py diff --git a/tutorials/hpc/internals/internals_2.py b/tutorials/hpc/02_internals/internals_2.py similarity index 100% rename from tutorials/hpc/internals/internals_2.py rename to tutorials/hpc/02_internals/internals_2.py diff --git a/tutorials/hpc/loading_preprocessing/loading_preprocessing_script.py b/tutorials/hpc/03_loading_preprocessing/loading_preprocessing_script.py similarity index 100% rename from tutorials/hpc/loading_preprocessing/loading_preprocessing_script.py rename to tutorials/hpc/03_loading_preprocessing/loading_preprocessing_script.py diff --git a/tutorials/hpc/matrix_factorizations/matrix_factorizations.py b/tutorials/hpc/04_matrix_factorizations/matrix_factorizations.py similarity index 100% rename from tutorials/hpc/matrix_factorizations/matrix_factorizations.py rename to tutorials/hpc/04_matrix_factorizations/matrix_factorizations.py diff --git a/tutorials/hpc/clustering/clustering.py b/tutorials/hpc/05_clustering/clustering.py similarity index 100% rename from tutorials/hpc/clustering/clustering.py rename to tutorials/hpc/05_clustering/clustering.py diff --git a/tutorials/hpc/clustering/iris.csv b/tutorials/hpc/05_clustering/iris.csv similarity index 100% rename from tutorials/hpc/clustering/iris.csv rename to tutorials/hpc/05_clustering/iris.csv From dfb898dbd297170c4320679d1fa8a99077b509ac Mon Sep 17 00:00:00 2001 From: Hoppe Date: Mon, 10 Jun 2024 14:59:17 +0200 Subject: [PATCH 3/7] worked on the tutorial further... --- ...cs_dndarrays.py => 01_basics_dndarrays.py} | 0 ...cs_datatypes.py => 02_basics_datatypes.py} | 0 ..._operations.py => 03_basics_operations.py} | 0 tutorials/hpc/01_basics/04_basics_indexing.py | 13 ++ ...cs_broadcast.py => 05_basics_broadcast.py} | 0 .../{basics_gpu.py => 06_basics_gpu.py} | 3 +- .../hpc/01_basics/07_basics_distributed.py | 70 +++++++ .../08_basics_distributed_operations.py | 24 +++ .../01_basics/09_basics_distributed_matmul.py | 55 ++++++ .../hpc/01_basics/10_interoperability.py | 26 +++ .../11_internals_1.py} | 2 +- .../12_internals_2.py} | 20 +- tutorials/hpc/01_basics/basics_distributed.py | 179 ------------------ tutorials/hpc/02_internals/internals.py | 130 ------------- .../hpc/02_loading_preprocessing/01_IO.py | 32 ++++ .../02_preprocessing.py} | 31 +-- .../iris.csv | 0 17 files changed, 224 insertions(+), 361 deletions(-) rename tutorials/hpc/01_basics/{basics_dndarrays.py => 01_basics_dndarrays.py} (100%) rename tutorials/hpc/01_basics/{basics_datatypes.py => 02_basics_datatypes.py} (100%) rename tutorials/hpc/01_basics/{basics_operations.py => 03_basics_operations.py} (100%) create mode 100644 tutorials/hpc/01_basics/04_basics_indexing.py rename tutorials/hpc/01_basics/{basics_broadcast.py => 05_basics_broadcast.py} (100%) rename tutorials/hpc/01_basics/{basics_gpu.py => 06_basics_gpu.py} (99%) create mode 100644 tutorials/hpc/01_basics/07_basics_distributed.py create mode 100644 tutorials/hpc/01_basics/08_basics_distributed_operations.py create mode 100644 tutorials/hpc/01_basics/09_basics_distributed_matmul.py create mode 100644 tutorials/hpc/01_basics/10_interoperability.py rename tutorials/hpc/{02_internals/internals_1.py => 01_basics/11_internals_1.py} (96%) rename tutorials/hpc/{02_internals/internals_2.py => 01_basics/12_internals_2.py} (82%) delete mode 100644 tutorials/hpc/01_basics/basics_distributed.py delete mode 100644 tutorials/hpc/02_internals/internals.py create mode 100644 tutorials/hpc/02_loading_preprocessing/01_IO.py rename tutorials/hpc/{03_loading_preprocessing/loading_preprocessing_script.py => 02_loading_preprocessing/02_preprocessing.py} (57%) rename tutorials/hpc/{05_clustering => 02_loading_preprocessing}/iris.csv (100%) diff --git a/tutorials/hpc/01_basics/basics_dndarrays.py b/tutorials/hpc/01_basics/01_basics_dndarrays.py similarity index 100% rename from tutorials/hpc/01_basics/basics_dndarrays.py rename to tutorials/hpc/01_basics/01_basics_dndarrays.py diff --git a/tutorials/hpc/01_basics/basics_datatypes.py b/tutorials/hpc/01_basics/02_basics_datatypes.py similarity index 100% rename from tutorials/hpc/01_basics/basics_datatypes.py rename to tutorials/hpc/01_basics/02_basics_datatypes.py diff --git a/tutorials/hpc/01_basics/basics_operations.py b/tutorials/hpc/01_basics/03_basics_operations.py similarity index 100% rename from tutorials/hpc/01_basics/basics_operations.py rename to tutorials/hpc/01_basics/03_basics_operations.py diff --git a/tutorials/hpc/01_basics/04_basics_indexing.py b/tutorials/hpc/01_basics/04_basics_indexing.py new file mode 100644 index 0000000000..0949a21f09 --- /dev/null +++ b/tutorials/hpc/01_basics/04_basics_indexing.py @@ -0,0 +1,13 @@ +import heat as ht + +# ## Indexing + +# Heat allows the indexing of arrays, and thereby, the extraction of a partial view of the elements in an array. It is possible to obtain single values as well as entire chunks, i.e. slices. + +a = ht.arange(10) + +print(a[3]) +print(a[1:7]) +print(a[::2]) + +# **NOTE:** Indexing in Heat is undergoing a [major overhaul](https://github.com/helmholtz-analytics/heat/pull/938), to increase interoperability with NumPy/PyTorch indexing, and to provide a fully distributed item setting functionality. Stay tuned for this feature in the next release. diff --git a/tutorials/hpc/01_basics/basics_broadcast.py b/tutorials/hpc/01_basics/05_basics_broadcast.py similarity index 100% rename from tutorials/hpc/01_basics/basics_broadcast.py rename to tutorials/hpc/01_basics/05_basics_broadcast.py diff --git a/tutorials/hpc/01_basics/basics_gpu.py b/tutorials/hpc/01_basics/06_basics_gpu.py similarity index 99% rename from tutorials/hpc/01_basics/basics_gpu.py rename to tutorials/hpc/01_basics/06_basics_gpu.py index 785972379f..2b19ed584c 100644 --- a/tutorials/hpc/01_basics/basics_gpu.py +++ b/tutorials/hpc/01_basics/06_basics_gpu.py @@ -1,4 +1,5 @@ import heat as ht +import torch # ## Parallel Processing # --- @@ -14,8 +15,6 @@ # Heat's array creation functions all support an additional parameter that which places the data on a specific device. By default, the CPU is selected, but it is also possible to directly allocate the data on a GPU. -import torch - if torch.cuda.is_available(): ht.zeros( ( diff --git a/tutorials/hpc/01_basics/07_basics_distributed.py b/tutorials/hpc/01_basics/07_basics_distributed.py new file mode 100644 index 0000000000..026ef9f460 --- /dev/null +++ b/tutorials/hpc/01_basics/07_basics_distributed.py @@ -0,0 +1,70 @@ +import heat as ht + +# ### Distributed Computing +# +# Heat is also able to make use of distributed processing capabilities such as those in high-performance cluster systems. For this, Heat exploits the fact that the operations performed on a multi-dimensional array are usually identical for all data items. Hence, a data-parallel processing strategy can be chosen, where the total number of data items is equally divided among all processing nodes. An operation is then performed individually on the local data chunks and, if necessary, communicates partial results behind the scenes. A Heat array assumes the role of a virtual overlay of the local chunks and realizes and coordinates the computations - see the figure below for a visual representation of this concept. +# +# +# +# The chunks are always split along a singular dimension (i.e. 1-D domain decomposition) of the array. You can specify this in Heat by using the `split` paramter. This parameter is present in all relevant functions, such as array creation (`zeros(), ones(), ...`) or I/O (`load()`) functions. +# +# +# +# +# Examples are provided below. The result of an operation on a Heat tensor will in most cases preserve the split of the respective operands. However, in some cases the split axis might change. For example, a transpose of a Heat array will equally transpose the split axis. Furthermore, a reduction operations, e.g. `sum()` that is performed across the split axis, might remove data partitions entirely. The respective function behaviors can be found in Heat's documentation. +# +# You may also modify the data partitioning of a Heat array by using the `resplit()` function. This allows you to repartition the data as you so choose. Please note, that this should be used sparingly and for small data amounts only, as it entails significant data copying across the network. Finally, a Heat array without any split, i.e. `split=None` (default), will result in redundant copies of data on each computation node. +# +# On a technical level, Heat follows the so-called [Bulk Synchronous Parallel (BSP)](https://en.wikipedia.org/wiki/Bulk_synchronous_parallel) processing model. For the network communication, Heat utilizes the [Message Passing Interface (MPI)](https://computing.llnl.gov/tutorials/mpi/), a *de facto* standard on modern high-performance computing systems. It is also possible to use MPI on your laptop or desktop computer. Respective software packages are available for all major operating systems. In order to run a Heat script, you need to start it slightly differently than you are probably used to. This +# +# ```bash +# python ./my_script.py +# ``` +# +# becomes this instead: +# +# ```bash +# mpirun -n python ./my_script.py +# ``` +# On an HPC cluster you'll of course use SBATCH or similar. +# +# +# Let's see some examples of working with distributed Heat: + +# In the following examples, we'll recreate the array shown in the figure, a 3-dimensional DNDarray of integers ranging from 0 to 59 (5 matrices of size (4,3)). + + +dndarray = ht.arange(60).reshape(5, 4, 3) +if dndarray.comm.rank == 0: + print("3-dimensional DNDarray of integers ranging from 0 to 59:") +print(dndarray) + + +# Notice the additional metadata printed with the DNDarray. With respect to a numpy ndarray, the DNDarray has additional information on the device (in this case, the CPU) and the `split` axis. In the example above, the split axis is `None`, meaning that the DNDarray is not distributed and each MPI process has a full copy of the data. +# +# Let's experiment with a distributed DNDarray: we'll split the same DNDarray as above, but distributed along the major axis. + + +dndarray = ht.arange(60, split=0).reshape(5, 4, 3) +if dndarray.comm.rank == 0: + print("3-dimensional DNDarray of integers ranging from 0 to 59:") +print(dndarray) + + +# The `split` axis is now 0, meaning that the DNDarray is distributed along the first axis. Each MPI process has a slice of the data along the first axis. In order to see the data on each process, we can print the "local array" via the `larray` attribute. + +print(f"data on process no {dndarray.comm.rank}: {dndarray.larray}") + + +# Note that the `larray` is a `torch.Tensor` object. This is the underlying tensor that holds the data. The `dndarray` object is an MPI-aware wrapper around these process-local tensors, providing memory-distributed functionality and information. + +# The DNDarray can be distributed along any axis. Modify the `split` attribute when creating the DNDarray in the cell above, to distribute it along a different axis, and see how the `larray`s change. You'll notice that the distributed arrays are always load-balanced, meaning that the data are distributed as evenly as possible across the MPI processes. + +# The `DNDarray` object has a number of methods and attributes that are useful for distributed computing. In particular, it keeps track of its global and local (on a given process) shape through distributed operations and array manipulations. The DNDarray is also associated to a `comm` object, the MPI communicator. +# +# (In MPI, the *communicator* is a group of processes that can communicate with each other. The `comm` object is a `MPI.COMM_WORLD` communicator, which is the default communicator that includes all the processes. The `comm` object is used to perform collective operations, such as reductions, scatter, gather, and broadcast. The `comm` object is also used to perform point-to-point communication between processes.) + + +print(f"Global shape on rank {dndarray.comm.rank}: {dndarray.shape}") +print(f"Local shape on rank: {dndarray.comm.rank}: {dndarray.lshape}") +print(f"Local device on rank: {dndarray.comm.rank}: {dndarray.device}") diff --git a/tutorials/hpc/01_basics/08_basics_distributed_operations.py b/tutorials/hpc/01_basics/08_basics_distributed_operations.py new file mode 100644 index 0000000000..a8bf106585 --- /dev/null +++ b/tutorials/hpc/01_basics/08_basics_distributed_operations.py @@ -0,0 +1,24 @@ +import heat as ht + +dndarray = ht.arange(60, split=0).reshape(5, 4, 3) + +# You can perform a vast number of operations on DNDarrays distributed over multi-node and/or multi-GPU resources. Check out our [Numpy coverage tables](https://github.com/helmholtz-analytics/heat/blob/main/coverage_tables.md) to see what operations are already supported. +# +# The result of an operation on DNDarays will in most cases preserve the `split` or distribution axis of the respective operands. However, in some cases the split axis might change. For example, a transpose of a Heat array will equally transpose the split axis. Furthermore, a reduction operations, e.g. `sum()` that is performed across the split axis, might remove data partitions entirely. The respective function behaviors can be found in Heat's documentation. + + +# transpose +print(dndarray.T) + + +# reduction operation along the distribution axis +print(dndarray.sum(axis=0)) + +# min / max etc. +print(ht.sin(dndarray).min(axis=0)) + + +other_dndarray = ht.arange(60, 120, split=0).reshape(5, 4, 3) # distributed reshape + +# element-wise multiplication +print(dndarray * other_dndarray) diff --git a/tutorials/hpc/01_basics/09_basics_distributed_matmul.py b/tutorials/hpc/01_basics/09_basics_distributed_matmul.py new file mode 100644 index 0000000000..d15ea26eb8 --- /dev/null +++ b/tutorials/hpc/01_basics/09_basics_distributed_matmul.py @@ -0,0 +1,55 @@ +# As we saw earlier, because the underlying data objects are PyTorch tensors, we can easily create DNDarrays on GPUs or move DNDarrays to GPUs. This allows us to perform distributed array operations on multi-GPU systems. +# +# So far we have demostrated small, easy-to-parallelize arithmetical operations. Let's move to linear algebra. Heat's `linalg` module supports a wide range of linear algebra operations, including matrix multiplication. Matrix multiplication is a very common operation data analysis, it is computationally intensive, and not trivial to parallelize. +# +# With Heat, you can perform matrix multiplication on distributed DNDarrays, and the operation will be parallelized across the MPI processes. Here on 4 GPUs: + +import heat as ht +import torch + +if torch.cuda.is_available(): + device = "gpu" +else: + device = "cpu" + +n, m = 400, 400 +x = ht.random.randn(n, m, split=0, device=device) # distributed RNG +y = ht.random.randn(m, n, split=None, device=device) +z = x @ y +print(z) + +# `ht.linalg.matmul` or `@` breaks down the matrix multiplication into a series of smaller `torch` matrix multiplications, which are then distributed across the MPI processes. This operation can be very communication-intensive on huge matrices that both require distribution, and users should choose the `split` axis carefully to minimize communication overhead. + +# You can experiment with sizes and the `split` parameter (distribution axis) for both matrices and time the result. Note that: +# - If you set **`split=None` for both matrices**, each process (in this case, each GPU) will attempt to multiply the entire matrices. Depending on the matrix sizes, the GPU memory might be insufficient. (And if you can multiply the matrices on a single GPU, it's much more efficient to stick to PyTorch's `torch.linalg.matmul` function.) +# - If **`split` is not None for both matrices**, each process will only hold a slice of the data, and will need to communicate data with other processes in order to perform the multiplication. This **introduces huge communication overhead**, but allows you to perform the multiplication on larger matrices than would fit in the memory of a single GPU. +# - If **`split` is None for one matrix and not None for the other**, the multiplication does not require communication, and the result will be distributed. If your data size allows it, you should always favor this option. +# +# Time the multiplication for different split parameters and see how the performance changes. +# +# + + +import time + +start = time.time() +z = x @ y +end = time.time() +print("runtime: ", end - start) + + +# Heat supports many linear algebra operations: +# ```bash +# >>> ht.linalg. +# ht.linalg.basics ht.linalg.hsvd_rtol( ht.linalg.projection( ht.linalg.triu( +# ht.linalg.cg( ht.linalg.inv( ht.linalg.qr( ht.linalg.vdot( +# ht.linalg.cross( ht.linalg.lanczos( ht.linalg.solver ht.linalg.vecdot( +# ht.linalg.det( ht.linalg.matmul( ht.linalg.svdtools ht.linalg.vector_norm( +# ht.linalg.dot( ht.linalg.matrix_norm( ht.linalg.trace( +# ht.linalg.hsvd( ht.linalg.norm( ht.linalg.transpose( +# ht.linalg.hsvd_rank( ht.linalg.outer( ht.linalg.tril( +# ``` +# +# and a lot more is in the works, including distributed eigendecompositions, SVD, and more. If the operation you need is not yet supported, leave us a note [here](tinyurl.com/demoissues) and we'll get back to you. + +# You can of course perform all operations on CPUs. You can leave out the `device` attribute entirely. diff --git a/tutorials/hpc/01_basics/10_interoperability.py b/tutorials/hpc/01_basics/10_interoperability.py new file mode 100644 index 0000000000..f3ec217425 --- /dev/null +++ b/tutorials/hpc/01_basics/10_interoperability.py @@ -0,0 +1,26 @@ +# ### Interoperability +# +# We can easily create DNDarrays from PyTorch tensors and numpy ndarrays. We can also convert DNDarrays to PyTorch tensors and numpy ndarrays. This makes it easy to integrate Heat into existing PyTorch and numpy workflows. +# + +# Heat will try to reuse the memory of the original array as much as possible. If you would prefer a copy with different memory, the ```copy``` keyword argument can be used when creating a DNDArray from other libraries. + +import heat as ht +import torch +import numpy as np + +torch_array = torch.arange(ht.MPI_WORLD.rank, ht.MPI_WORLD.rank + 5) +heat_array = ht.array(torch_array, copy=False, is_split=0) +heat_array[0] = -1 +print(torch_array) + +torch_array = torch.arange(ht.MPI_WORLD.rank, ht.MPI_WORLD.rank + 5) +heat_array = ht.array(torch_array, copy=True, is_split=0) +heat_array[0] = -1 +print(torch_array) + +np_array = heat_array.numpy() +print(np_array) + + +# Interoperability is a key feature of Heat, and we are constantly working to increase Heat's compliance to the [Python array API standard](https://data-apis.org/array-api/latest/). As usual, please [let us know](tinyurl.com/demoissues) if you encounter any issues or have any feature requests. diff --git a/tutorials/hpc/02_internals/internals_1.py b/tutorials/hpc/01_basics/11_internals_1.py similarity index 96% rename from tutorials/hpc/02_internals/internals_1.py rename to tutorials/hpc/01_basics/11_internals_1.py index dfeed9ba74..d8c1dae30d 100644 --- a/tutorials/hpc/02_internals/internals_1.py +++ b/tutorials/hpc/01_basics/11_internals_1.py @@ -11,7 +11,7 @@ a = ht.random.randn(7, 4, 3, split=1) if a.comm.rank == 0: - print(f"a.com gets the communicator {a.comm} associated with DNDarray a") + print(f"a.comm gets the communicator {a.comm} associated with DNDarray a") # MPI size = total number of processes size = a.comm.size diff --git a/tutorials/hpc/02_internals/internals_2.py b/tutorials/hpc/01_basics/12_internals_2.py similarity index 82% rename from tutorials/hpc/02_internals/internals_2.py rename to tutorials/hpc/01_basics/12_internals_2.py index 99aaf647fe..94d71a445d 100644 --- a/tutorials/hpc/02_internals/internals_2.py +++ b/tutorials/hpc/01_basics/12_internals_2.py @@ -61,7 +61,7 @@ # join them into a distributed array a_0 = ht.array(local_array, is_split=0) -a_0.shape +print(a_0.shape) # Change the cell above and join the arrays along a different axis. Note that the shapes of the local arrays must be consistent along the non-split axes. They can differ along the split axis. @@ -69,21 +69,3 @@ # The `ht.array` function takes any data object as an input that can be converted to a torch tensor. # Once you've made your disjoint data into a DNDarray, you can apply any Heat operation or algorithm to it and exploit the cumulative RAM of all the processes in the communicator. - -# You can access the MPI communication functionalities of the DNDarray through the `comm` attribute, i.e.: -# -# ```python -# # these are just examples, this cell won't do anything -# a.comm.Allreduce(a, b, op=MPI.SUM) -# -# a.comm.Allgather(a, b) -# a.comm.Isend(a, dest=1, tag=0) -# ``` -# -# etc. - -# In the next notebooks, we'll show you how we use Heat's distributed-array infrastructure to scale complex data analysis workflows to large datasets and high-performance computing resources. -# -# - [Data loading and preprocessing](4_loading_preprocessing.ipynb) -# - [Matrix factorization algorithms](5_matrix_factorizations.ipynb) -# - [Clustering algorithms](6_clustering.ipynb) diff --git a/tutorials/hpc/01_basics/basics_distributed.py b/tutorials/hpc/01_basics/basics_distributed.py deleted file mode 100644 index 18be27c7fa..0000000000 --- a/tutorials/hpc/01_basics/basics_distributed.py +++ /dev/null @@ -1,179 +0,0 @@ -import heat as ht - -# ### Distributed Computing -# -# Heat is also able to make use of distributed processing capabilities such as those in high-performance cluster systems. For this, Heat exploits the fact that the operations performed on a multi-dimensional array are usually identical for all data items. Hence, a data-parallel processing strategy can be chosen, where the total number of data items is equally divided among all processing nodes. An operation is then performed individually on the local data chunks and, if necessary, communicates partial results behind the scenes. A Heat array assumes the role of a virtual overlay of the local chunks and realizes and coordinates the computations - see the figure below for a visual representation of this concept. -# -# -# -# The chunks are always split along a singular dimension (i.e. 1-D domain decomposition) of the array. You can specify this in Heat by using the `split` paramter. This parameter is present in all relevant functions, such as array creation (`zeros(), ones(), ...`) or I/O (`load()`) functions. -# -# -# -# -# Examples are provided below. The result of an operation on a Heat tensor will in most cases preserve the split of the respective operands. However, in some cases the split axis might change. For example, a transpose of a Heat array will equally transpose the split axis. Furthermore, a reduction operations, e.g. `sum()` that is performed across the split axis, might remove data partitions entirely. The respective function behaviors can be found in Heat's documentation. -# -# You may also modify the data partitioning of a Heat array by using the `resplit()` function. This allows you to repartition the data as you so choose. Please note, that this should be used sparingly and for small data amounts only, as it entails significant data copying across the network. Finally, a Heat array without any split, i.e. `split=None` (default), will result in redundant copies of data on each computation node. -# -# On a technical level, Heat follows the so-called [Bulk Synchronous Parallel (BSP)](https://en.wikipedia.org/wiki/Bulk_synchronous_parallel) processing model. For the network communication, Heat utilizes the [Message Passing Interface (MPI)](https://computing.llnl.gov/tutorials/mpi/), a *de facto* standard on modern high-performance computing systems. It is also possible to use MPI on your laptop or desktop computer. Respective software packages are available for all major operating systems. In order to run a Heat script, you need to start it slightly differently than you are probably used to. This -# -# ```bash -# python ./my_script.py -# ``` -# -# becomes this instead: -# -# ```bash -# mpirun -n python ./my_script.py -# ``` -# On an HPC cluster you'll of course use SBATCH or similar. -# -# -# Let's see some examples of working with distributed Heat: - -# In the following examples, we'll recreate the array shown in the figure, a 3-dimensional DNDarray of integers ranging from 0 to 59 (5 matrices of size (4,3)). - - -dndarray = ht.arange(60).reshape(5, 4, 3) -if dndarray.comm.rank == 0: - print(f"3-dimensional DNDarray of integers ranging from 0 to 59: {dndarray}") - - -# Notice the additional metadata printed with the DNDarray. With respect to a numpy ndarray, the DNDarray has additional information on the device (in this case, the CPU) and the `split` axis. In the example above, the split axis is `None`, meaning that the DNDarray is not distributed and each MPI process has a full copy of the data. -# -# Let's experiment with a distributed DNDarray: we'll split the same DNDarray as above, but distributed along the major axis. - - -dndarray = ht.arange(60, split=0).reshape(5, 4, 3) -if dndarray.comm.rank == 0: - print(f"3-dimensional DNDarray splitted across dim 0: {dndarray}") - - -# The `split` axis is now 0, meaning that the DNDarray is distributed along the first axis. Each MPI process has a slice of the data along the first axis. In order to see the data on each process, we can print the "local array" via the `larray` attribute. - - -if dndarray.comm.rank == 0: - print(f"data on each process: {dndarray.larray}") - - -# Note that the `larray` is a `torch.Tensor` object. This is the underlying tensor that holds the data. The `dndarray` object is an MPI-aware wrapper around these process-local tensors, providing memory-distributed functionality and information. - -# The DNDarray can be distributed along any axis. Modify the `split` attribute when creating the DNDarray in the cell above, to distribute it along a different axis, and see how the `larray`s change. You'll notice that the distributed arrays are always load-balanced, meaning that the data are distributed as evenly as possible across the MPI processes. - -# The `DNDarray` object has a number of methods and attributes that are useful for distributed computing. In particular, it keeps track of its global and local (on a given process) shape through distributed operations and array manipulations. The DNDarray is also associated to a `comm` object, the MPI communicator. -# -# (In MPI, the *communicator* is a group of processes that can communicate with each other. The `comm` object is a `MPI.COMM_WORLD` communicator, which is the default communicator that includes all the processes. The `comm` object is used to perform collective operations, such as reductions, scatter, gather, and broadcast. The `comm` object is also used to perform point-to-point communication between processes.) - - -print(f"Global shape on rank {dndarray.comm.rank}: {dndarray.shape}") -print(f"Local shape on rank: {dndarray.comm.rank}: {dndarray.lshape}") - - -# You can perform a vast number of operations on DNDarrays distributed over multi-node and/or multi-GPU resources. Check out our [Numpy coverage tables](https://github.com/helmholtz-analytics/heat/blob/main/coverage_tables.md) to see what operations are already supported. -# -# The result of an operation on DNDarays will in most cases preserve the `split` or distribution axis of the respective operands. However, in some cases the split axis might change. For example, a transpose of a Heat array will equally transpose the split axis. Furthermore, a reduction operations, e.g. `sum()` that is performed across the split axis, might remove data partitions entirely. The respective function behaviors can be found in Heat's documentation. - - -# transpose -print(dndarray.T) - - -# reduction operation along the distribution axis -print(dndarray.sum(axis=0)) - - -other_dndarray = ht.arange(60, 120, split=0).reshape(5, 4, 3) # distributed reshape - -# element-wise multiplication -print(dndarray * other_dndarray) - - -# As we saw earlier, because the underlying data objects are PyTorch tensors, we can easily create DNDarrays on GPUs or move DNDarrays to GPUs. This allows us to perform distributed array operations on multi-GPU systems. -# -# So far we have demostrated small, easy-to-parallelize arithmetical operations. Let's move to linear algebra. Heat's `linalg` module supports a wide range of linear algebra operations, including matrix multiplication. Matrix multiplication is a very common operation data analysis, it is computationally intensive, and not trivial to parallelize. -# -# With Heat, you can perform matrix multiplication on distributed DNDarrays, and the operation will be parallelized across the MPI processes. Here on 4 GPUs: - - -import torch - -if torch.cuda.is_available(): - device = "gpu" -else: - device = "cpu" - -n, m = 400, 400 -x = ht.random.randn(n, m, split=0, device=device) # distributed RNG -y = ht.random.randn(m, n, split=None, device=device) -z = x @ y -print(z) - -# `ht.linalg.matmul` or `@` breaks down the matrix multiplication into a series of smaller `torch` matrix multiplications, which are then distributed across the MPI processes. This operation can be very communication-intensive on huge matrices that both require distribution, and users should choose the `split` axis carefully to minimize communication overhead. - -# You can experiment with sizes and the `split` parameter (distribution axis) for both matrices and time the result. Note that: -# - If you set **`split=None` for both matrices**, each process (in this case, each GPU) will attempt to multiply the entire matrices. Depending on the matrix sizes, the GPU memory might be insufficient. (And if you can multiply the matrices on a single GPU, it's much more efficient to stick to PyTorch's `torch.linalg.matmul` function.) -# - If **`split` is not None for both matrices**, each process will only hold a slice of the data, and will need to communicate data with other processes in order to perform the multiplication. This **introduces huge communication overhead**, but allows you to perform the multiplication on larger matrices than would fit in the memory of a single GPU. -# - If **`split` is None for one matrix and not None for the other**, the multiplication does not require communication, and the result will be distributed. If your data size allows it, you should always favor this option. -# -# Time the multiplication for different split parameters and see how the performance changes. -# -# - - -import time - -start = time.time() -z = x @ y -end = time.time() -print("runtime: ", end - start) - - -# Heat supports many linear algebra operations: -# ```bash -# >>> ht.linalg. -# ht.linalg.basics ht.linalg.hsvd_rtol( ht.linalg.projection( ht.linalg.triu( -# ht.linalg.cg( ht.linalg.inv( ht.linalg.qr( ht.linalg.vdot( -# ht.linalg.cross( ht.linalg.lanczos( ht.linalg.solver ht.linalg.vecdot( -# ht.linalg.det( ht.linalg.matmul( ht.linalg.svdtools ht.linalg.vector_norm( -# ht.linalg.dot( ht.linalg.matrix_norm( ht.linalg.trace( -# ht.linalg.hsvd( ht.linalg.norm( ht.linalg.transpose( -# ht.linalg.hsvd_rank( ht.linalg.outer( ht.linalg.tril( -# ``` -# -# and a lot more is in the works, including distributed eigendecompositions, SVD, and more. If the operation you need is not yet supported, leave us a note [here](tinyurl.com/demoissues) and we'll get back to you. - -# You can of course perform all operations on CPUs. You can leave out the `device` attribute entirely. - -# ### Interoperability -# -# We can easily create DNDarrays from PyTorch tensors and numpy ndarrays. We can also convert DNDarrays to PyTorch tensors and numpy ndarrays. This makes it easy to integrate Heat into existing PyTorch and numpy workflows. Here a basic example with xarrays: - - -import xarray as xr - -local_xr = xr.DataArray(dndarray.larray, dims=("z", "y", "x")) -# proceed with local xarray operations -print(local_xr) - - -# **NOTE:** this is not a distributed `xarray`, but local xarray objects on each rank. -# Work on [expanding xarray support](https://github.com/helmholtz-analytics/heat/pull/1183) is ongoing. -# - -# Heat will try to reuse the memory of the original array as much as possible. If you would prefer a copy with different memory, the ```copy``` keyword argument can be used when creating a DNDArray from other libraries. - - -import torch - -torch_array = torch.arange(5) -heat_array = ht.array(torch_array, copy=False) -heat_array[0] = -1 -print(torch_array) - -torch_array = torch.arange(5) -heat_array = ht.array(torch_array, copy=True) -heat_array[0] = -1 -print(torch_array) - - -# Interoperability is a key feature of Heat, and we are constantly working to increase Heat's compliance to the [Python array API standard](https://data-apis.org/array-api/latest/). As usual, please [let us know](tinyurl.com/demoissues) if you encounter any issues or have any feature requests. diff --git a/tutorials/hpc/02_internals/internals.py b/tutorials/hpc/02_internals/internals.py deleted file mode 100644 index 08f76080e0..0000000000 --- a/tutorials/hpc/02_internals/internals.py +++ /dev/null @@ -1,130 +0,0 @@ -import heat as ht -import torch - -# # Heat as infrastructure for MPI applications -# -# In this section, we'll go through some Heat-specific functionalities that simplify the implementation of a data-parallel application in Python. We'll demonstrate them on small arrays and 4 processes on a single cluster node, but the functionalities are indeed meant for a multi-node set up with huge arrays that cannot be processed on a single node. - - -# We already mentioned that the DNDarray object is "MPI-aware". Each DNDarray is associated to an MPI communicator, it is aware of the number of processes in the communicator, and it knows the rank of the process that owns it. -# - -a = ht.random.randn(7, 4, 3, split=0) -if a.comm.rank == 0: - print(f"a.com gets the communicator {a.comm} associated with DNDarray a") - -# MPI size = total number of processes -size = a.comm.size - -if a.comm.rank == 0: - print(f"a is distributed over {size} processes") - print(f"a is a distributed {a.ndim}-dimensional array with global shape {a.shape}") - - -# MPI rank = rank of each process -rank = a.comm.rank -# Local shape = shape of the data on each process -local_shape = a.lshape -print(f"Rank {rank} holds a slice of a with local shape {local_shape}") - - -# ### Distribution map -# -# In many occasions, when building a memory-distributed pipeline it will be convenient for each rank to have information on what ranks holds which slice of the distributed array. -# -# The `lshape_map` attribute of a DNDarray gathers (or, if possible, calculates) this info from all processes and stores it as metadata of the DNDarray. Because it is meant for internal use, it is stored in a torch tensor, not a DNDarray. -# -# The `lshape_map` tensor is a 2D tensor, where the first dimension is the number of processes and the second dimension is the number of dimensions of the array. Each row of the tensor contains the local shape of the array on a process. - - -lshape_map = a.lshape_map -if a.comm.rank == 0: - print(f"lshape_map available on any process: {lshape_map}") - -# Go back to where we created the DNDarray and and create `a` with a different split axis. See how the `lshape_map` changes. - -# ### Modifying the DNDarray distribution -# -# In a distributed pipeline, it is sometimes necessary to change the distribution of a DNDarray, when the array is not distributed in the most convenient way for the next operation / algorithm. -# -# Depending on your needs, you can choose between: -# - `DNDarray.redistribute_()`: This method keeps the original split axis, but redistributes the data of the DNDarray according to a "target map". -# - `DNDarray.resplit_()`: This method changes the split axis of the DNDarray. This is a more expensive operation, and should be used only when absolutely necessary. Depending on your needs and available resources, in some cases it might be wiser to keep a copy of the DNDarray with a different split axis. -# -# Let's see some examples. - - -# redistribute -target_map = a.lshape_map -target_map[:, a.split] = torch.tensor([1, 2, 2, 2]) -# in-place redistribution (see ht.redistribute for out-of-place) -a.redistribute_(target_map=target_map) - -# new lshape map after redistribution -a.lshape_map - -# local arrays after redistribution -a.larray - - -# resplit -a.resplit_(axis=1) - -a.lshape_map - - -# You can use the `resplit_` method (in-place), or `ht.resplit` (out-of-place) to change the distribution axis, but also to set the distribution axis to None. The latter corresponds to an MPI.Allgather operation that gathers the entire array on each process. This is useful when you've achieved a small enough data size that can be processed on a single device, and you want to avoid communication overhead. - - -# "un-split" distributed array -a.resplit_(axis=None) -# each process now holds a copy of the entire array - - -# The opposite is not true, i.e. you cannot use `resplit_` to distribute an array with split=None. In that case, you must use the `ht.array()` factory function: - - -# make `a` split again -a = ht.array(a, split=0) - - -# ### Making disjoint data into a global DNDarray -# -# Another common occurrence in a data-parallel pipeline: you have addressed the embarassingly-parallel part of your algorithm with any array framework, each process working independently from the others. You now want to perform a non-embarassingly-parallel operation on the entire dataset, with Heat as a backend. -# -# You can use the `ht.array` factory function with the `is_split` argument to create a DNDarray from a disjoint (on each MPI process) set of arrays. The `is_split` argument indicates the axis along which the disjoint data is to be "joined" into a global, distributed DNDarray. - - -# create some random local arrays on each process -import numpy as np - -local_array = np.random.rand(3, 4) - -# join them into a distributed array -a_0 = ht.array(local_array, is_split=0) -a_0.shape - - -# Change the cell above and join the arrays along a different axis. Note that the shapes of the local arrays must be consistent along the non-split axes. They can differ along the split axis. - -# The `ht.array` function takes any data object as an input that can be converted to a torch tensor. - -# Once you've made your disjoint data into a DNDarray, you can apply any Heat operation or algorithm to it and exploit the cumulative RAM of all the processes in the communicator. - -# You can access the MPI communication functionalities of the DNDarray through the `comm` attribute, i.e.: -# -# ```python -# # these are just examples, this cell won't do anything -# a.comm.Allreduce(a, b, op=MPI.SUM) -# -# a.comm.Allgather(a, b) -# a.comm.Isend(a, dest=1, tag=0) -# ``` -# -# etc. - -# In the next notebooks, we'll show you how we use Heat's distributed-array infrastructure to scale complex data analysis workflows to large datasets and high-performance computing resources. -# -# - [Data loading and preprocessing](4_loading_preprocessing.ipynb) -# - [Matrix factorization algorithms](5_matrix_factorizations.ipynb) -# - [Clustering algorithms](6_clustering.ipynb) diff --git a/tutorials/hpc/02_loading_preprocessing/01_IO.py b/tutorials/hpc/02_loading_preprocessing/01_IO.py new file mode 100644 index 0000000000..6452d086f4 --- /dev/null +++ b/tutorials/hpc/02_loading_preprocessing/01_IO.py @@ -0,0 +1,32 @@ +# # Loading and Preprocessing +# +# ### Refresher +# +# Using PyTorch as compute engine and mpi4py for communication, Heat implements a number of array operations and algorithms that are optimized for memory-distributed data volumes. This allows you to tackle datasets that are too large for single-node (or worse, single-GPU) processing. +# +# As opposed to task-parallel frameworks, Heat takes a data-parallel approach, meaning that each "worker" or MPI process performs the same tasks on different slices of the data. Many operations and algorithms are not embarassingly parallel, and involve data exchange between processes. Heat operations and algorithms are designed to minimize this communication overhead, and to make it transparent to the user. +# +# In other words: +# - you don't have to worry about optimizing data chunk sizes; +# - you don't have to make sure your research problem is embarassingly parallel, or artificially make your dataset smaller so your RAM is sufficient; +# - you do have to make sure that you have sufficient **overall** RAM to run your global task (e.g. number of nodes / GPUs). + +# The following shows some I/O and preprocessing examples. We'll use small datasets here as each of us only has access to one node only. + +# ### I/O +# +# Let's start with loading a data set. Heat supports reading and writing from/into shared memory for a number of formats, including HDF5, NetCDF, and because we love scientists, csv. Check out the `ht.load` and `ht.save` functions for more details. Here we will load data in [HDF5 format](https://en.wikipedia.org/wiki/Hierarchical_Data_Format). +# +# Now let's import `heat` and load a data set. + +import heat as ht + +# Some random data for small scale tests +iris = ht.load("iris.csv", sep=";", split=0) +print(iris) + +# We have loaded the entire data onto 4 MPI processes, each with 12 cores. We have created `X` with `split=0`, so each process stores evenly-sized slices of the data along dimension 0. + +# similar for HDF5 +X = ht.load_hdf5("path_to_data/sbdb_asteroids.h5", device="gpu", dataset="data", split=0) +print(X.shape) diff --git a/tutorials/hpc/03_loading_preprocessing/loading_preprocessing_script.py b/tutorials/hpc/02_loading_preprocessing/02_preprocessing.py similarity index 57% rename from tutorials/hpc/03_loading_preprocessing/loading_preprocessing_script.py rename to tutorials/hpc/02_loading_preprocessing/02_preprocessing.py index 3d37992c21..33c8729765 100644 --- a/tutorials/hpc/03_loading_preprocessing/loading_preprocessing_script.py +++ b/tutorials/hpc/02_loading_preprocessing/02_preprocessing.py @@ -1,35 +1,6 @@ -# # Loading and Preprocessing -# -# ### Refresher -# -# Using PyTorch as compute engine and mpi4py for communication, Heat implements a number of array operations and algorithms that are optimized for memory-distributed data volumes. This allows you to tackle datasets that are too large for single-node (or worse, single-GPU) processing. -# -# As opposed to task-parallel frameworks, Heat takes a data-parallel approach, meaning that each "worker" or MPI process performs the same tasks on different slices of the data. Many operations and algorithms are not embarassingly parallel, and involve data exchange between processes. Heat operations and algorithms are designed to minimize this communication overhead, and to make it transparent to the user. -# -# In other words: -# - you don't have to worry about optimizing data chunk sizes; -# - you don't have to make sure your research problem is embarassingly parallel, or artificially make your dataset smaller so your RAM is sufficient; -# - you do have to make sure that you have sufficient **overall** RAM to run your global task (e.g. number of nodes / GPUs). - -# The following shows some I/O and preprocessing examples. We'll use small datasets here as each of us only has access to one node only. - -# ### I/O -# -# Let's start with loading a data set. Heat supports reading and writing from/into shared memory for a number of formats, including HDF5, NetCDF, and because we love scientists, csv. Check out the `ht.load` and `ht.save` functions for more details. Here we will load data in [HDF5 format](https://en.wikipedia.org/wiki/Hierarchical_Data_Format). -# -# This particular example data set (generated from all Asteroids from the [JPL Small Body Database](https://ssd.jpl.nasa.gov/sb/)) is really small, but it allows to demonstrate the basic functionality of Heat. -# - -# The above cell should return [0, 1, 2, 3]. -# -# Now let's import `heat` and load the data set. - import heat as ht -# X = ht.load_hdf5("../data/sbdb_asteroids.h5",dtype=ht.float64,dataset="data",split=0) - -# Some random data for small scale tests -X = ht.random.randn(1000, 3, split=0) +X = ht.random.randn(1000, 3, split=0, device="gpu") # We have loaded the entire data onto 4 MPI processes, each with 12 cores. We have created `X` with `split=0`, so each process stores evenly-sized slices of the data along dimension 0. diff --git a/tutorials/hpc/05_clustering/iris.csv b/tutorials/hpc/02_loading_preprocessing/iris.csv similarity index 100% rename from tutorials/hpc/05_clustering/iris.csv rename to tutorials/hpc/02_loading_preprocessing/iris.csv From c00720649443fdf1449702e8c4c89b4b8dec2f5c Mon Sep 17 00:00:00 2001 From: Hoppe Date: Mon, 10 Jun 2024 16:32:25 +0200 Subject: [PATCH 4/7] further modifications --- .../hpc/02_loading_preprocessing/01_IO.py | 12 +- .../matrix_factorizations.py | 15 +- tutorials/hpc/04_clustering/clustering.py | 68 +++++++ tutorials/hpc/05_clustering/clustering.py | 182 ------------------ .../hpc/05_your_turn/now_its_your_turn.py | 48 +++++ 5 files changed, 129 insertions(+), 196 deletions(-) rename tutorials/hpc/{04_matrix_factorizations => 03_matrix_factorizations}/matrix_factorizations.py (88%) create mode 100644 tutorials/hpc/04_clustering/clustering.py delete mode 100644 tutorials/hpc/05_clustering/clustering.py create mode 100644 tutorials/hpc/05_your_turn/now_its_your_turn.py diff --git a/tutorials/hpc/02_loading_preprocessing/01_IO.py b/tutorials/hpc/02_loading_preprocessing/01_IO.py index 6452d086f4..0c9403767b 100644 --- a/tutorials/hpc/02_loading_preprocessing/01_IO.py +++ b/tutorials/hpc/02_loading_preprocessing/01_IO.py @@ -28,5 +28,13 @@ # We have loaded the entire data onto 4 MPI processes, each with 12 cores. We have created `X` with `split=0`, so each process stores evenly-sized slices of the data along dimension 0. # similar for HDF5 -X = ht.load_hdf5("path_to_data/sbdb_asteroids.h5", device="gpu", dataset="data", split=0) -print(X.shape) + +# first, we generate some data +X = ht.random.randn(10000, 100, split=0) + +# ... and save it to file +ht.save(X, "~/mydata.h5", "mydata", mode="a") + +# ... then we can load it again +Y = ht.load_hdf5("~/mydata.h5", device="gpu", dataset="mydata", split=0) +print(ht.allclose(X, Y)) diff --git a/tutorials/hpc/04_matrix_factorizations/matrix_factorizations.py b/tutorials/hpc/03_matrix_factorizations/matrix_factorizations.py similarity index 88% rename from tutorials/hpc/04_matrix_factorizations/matrix_factorizations.py rename to tutorials/hpc/03_matrix_factorizations/matrix_factorizations.py index 1a77b2436e..1543c81efe 100644 --- a/tutorials/hpc/04_matrix_factorizations/matrix_factorizations.py +++ b/tutorials/hpc/03_matrix_factorizations/matrix_factorizations.py @@ -60,9 +60,7 @@ import heat as ht -X = ht.load_hdf5( - "/p/scratch/training2404/data/JPL_SBDB/sbdb_asteroids.h5", dataset="data", split=0 -).T +X = ht.load_hdf5("~/mydata.h5", dataset="mydata", split=0).T # Note that due to the transpose, `X` is distributed along the columns now; this is required by the hSVD-algorithm. @@ -83,19 +81,12 @@ # Once we have computed the truncated SVD, we can use it to approximate the original data matrix `X` by the truncated matrix `X_r`. # -# Check out the plot below to see how Heat's truncated SVD algorithm scales with the number of MPI processes and size of the dataset. - -#
-# -# -# -#
-# +# Check out https://helmholtz-analytics.github.io/heat/2023/06/16/new-feature-hsvd.html to see how Heat's truncated SVD algorithm scales with the number of MPI processes and size of the dataset. # ### Other factorizations # # Other common factorization algorithms are supported in Heat, such as: -# - QR decomposition (`heat.linalg.qr`), +# - QR decomposition (`heat.linalg.qr`) # - Lanczos algorithm for computing the largest eigenvalues and corresponding eigenvectors (`heat.linalg.lanczos`) # # Check out our [`linalg` PRs](https://github.com/helmholtz-analytics/heat/pulls?q=is%3Aopen+is%3Apr+label%3Alinalg) to see what's in progress. diff --git a/tutorials/hpc/04_clustering/clustering.py b/tutorials/hpc/04_clustering/clustering.py new file mode 100644 index 0000000000..85c6e2c5e3 --- /dev/null +++ b/tutorials/hpc/04_clustering/clustering.py @@ -0,0 +1,68 @@ +# Cluster Analysis +# ================ +# +# This tutorial is an interactive version of our static [clustering tutorial on ReadTheDocs](https://heat.readthedocs.io/en/stable/tutorial_clustering.html). +# +# We will demonstrate memory-distributed analysis with k-means and k-medians from the ``heat.cluster`` module. As usual, we will run the analysis on a small dataset for demonstration. We need to have an `ipcluster` running to distribute the computation. +# +# We will use matplotlib for visualization of data and results. + + +import heat as ht + +# The Iris Dataset +# ------------------------------ +# The _iris_ dataset is a well known example for clustering analysis. It contains 4 measured features for samples from +# three different types of iris flowers. A subset of 150 samples is included in formats h5, csv and netcdf in the [Heat repository under 'heat/heat/datasets'](https://github.com/helmholtz-analytics/heat/tree/main/heat/datasets), and can be loaded in a distributed manner with Heat's parallel dataloader. +# +# **NOTE: you might have to change the path to the dataset in the following cell.** + +iris = ht.load("~/heat/tutorials/hpc/02_loading_preprocessing/iris.csv", sep=";", split=0) + + +# Feel free to try out the other [loading options](https://heat.readthedocs.io/en/stable/autoapi/heat/core/io/index.html#heat.core.io.load) as well. +# +# Fitting the dataset with `kmeans`: + +k = 3 +kmeans = ht.cluster.KMeans(n_clusters=k, init="kmeans++") +kmeans.fit(iris) + +# Let's see what the results are. In theory, there are 50 samples of each of the 3 iris types: setosa, versicolor and virginica. We will plot the results in a 3D scatter plot, coloring the samples according to the assigned cluster. + +labels = kmeans.predict(iris).squeeze() + +# Select points assigned to clusters c1, c2 and c3 +c1 = iris[ht.where(labels == 0), :] +c2 = iris[ht.where(labels == 1), :] +c3 = iris[ht.where(labels == 2), :] +# After slicing, the arrays are not distributed equally among the processes anymore; we need to balance +# TODO is balancing really necessary? +c1.balance_() +c2.balance_() +c3.balance_() + +print( + f"Number of points assigned to c1: {c1.shape[0]} \n" + f"Number of points assigned to c2: {c2.shape[0]} \n" + f"Number of points assigned to c3: {c3.shape[0]}" +) + + +# compare Heat results with sklearn +from sklearn.cluster import KMeans +import sklearn.datasets + +k = 3 +iris_sk = sklearn.datasets.load_iris().data +kmeans_sk = KMeans(n_clusters=k, init="k-means++").fit(iris_sk) +labels_sk = kmeans_sk.predict(iris_sk) + +c1_sk = iris_sk[labels_sk == 0, :] +c2_sk = iris_sk[labels_sk == 1, :] +c3_sk = iris_sk[labels_sk == 2, :] +print( + f"Number of points assigned to c1: {c1_sk.shape[0]} \n" + f"Number of points assigned to c2: {c2_sk.shape[0]} \n" + f"Number of points assigned to c3: {c3_sk.shape[0]}" +) diff --git a/tutorials/hpc/05_clustering/clustering.py b/tutorials/hpc/05_clustering/clustering.py deleted file mode 100644 index 24df3da224..0000000000 --- a/tutorials/hpc/05_clustering/clustering.py +++ /dev/null @@ -1,182 +0,0 @@ -# Cluster Analysis -# ================ -# -# This tutorial is an interactive version of our static [clustering tutorial on ReadTheDocs](https://heat.readthedocs.io/en/stable/tutorial_clustering.html). -# -# We will demonstrate memory-distributed analysis with k-means and k-medians from the ``heat.cluster`` module. As usual, we will run the analysis on a small dataset for demonstration. We need to have an `ipcluster` running to distribute the computation. -# -# We will use matplotlib for visualization of data and results. - - -import heat as ht - - -# Spherical Clouds of Datapoints -# ------------------------------ -# For a simple demonstration of the clustering process and the differences between the algorithms, we will create an -# artificial dataset, consisting of two circularly shaped clusters positioned at $(x_1=2, y_1=2)$ and $(x_2=-2, y_2=-2)$ in 2D space. -# For each cluster we will sample 100 arbitrary points from a circle with radius of $R = 1.0$ by drawing random numbers -# for the spherical coordinates $( r\in [0,R], \phi \in [0,2\pi])$, translating these to cartesian coordinates -# and shifting them by $+2$ for cluster ``c1`` and $-2$ for cluster ``c2``. The resulting concatenated dataset ``data`` has shape -# $(200, 2)$ and is distributed among the ``p`` processes along axis 0 (sample axis). - - -num_ele = 100 -R = 1.0 - -# Create default spherical point cloud -# Sample radius between 0 and 1, and phi between 0 and 2pi -r = ht.random.rand(num_ele, split=0) * R -phi = ht.random.rand(num_ele, split=0) * 2 * ht.constants.PI - -# Transform spherical coordinates to cartesian coordinates -x = r * ht.cos(phi) -y = r * ht.sin(phi) - - -# Stack the sampled points and shift them to locations (2,2) and (-2, -2) -cluster1 = ht.stack((x + 2, y + 2), axis=1) -cluster2 = ht.stack((x - 2, y - 2), axis=1) - -data = ht.concatenate((cluster1, cluster2), axis=0) - - -# Let's plot the data for illustration. In order to do so with matplotlib, we need to unsplit the data (gather it from -# all processes) and transform it into a numpy array. Plotting can only be done on rank 0. - - -data_np = ht.resplit(data, axis=None).numpy() - - -# import matplotlib.pyplot as plt -# plt.plot(data_np[:,0], data_np[:,1], 'bo') - - -# Now we perform the clustering analysis with kmeans. We chose 'kmeans++' as an intelligent way of sampling the -# initial centroids. - -kmeans = ht.cluster.KMeans(n_clusters=2, init="kmeans++") -labels = kmeans.fit_predict(data).squeeze() -centroids = kmeans.cluster_centers_ - -# Select points assigned to clusters c1 and c2 -c1 = data[ht.where(labels == 0), :] -c2 = data[ht.where(labels == 1), :] -# After slicing, the arrays are no longer distributed evenly among the processes; we might need to balance the load -c1.balance_() # in-place operation -c2.balance_() - -print( - f"Number of points assigned to c1: {c1.shape[0]} \n" - f"Number of points assigned to c2: {c2.shape[0]} \n" - f"Centroids = {centroids}" -) - - -# Let's plot the assigned clusters and the respective centroids: - -# just for plotting: collect all the data on each process and extract the numpy arrays. This will copy data to CPU if necessary. -c1_np = c1.numpy() -c2_np = c2.numpy() - -""" -import matplotlib.pyplot as plt -# plotting on 1 process only -plt.plot(c1_np[:,0], c1_np[:,1], 'x', color='#f0781e') -plt.plot(c2_np[:,0], c2_np[:,1], 'x', color='#5a696e') -plt.plot(centroids[0,0],centroids[0,1], '^', markersize=10, markeredgecolor='black', color='#f0781e' ) -plt.plot(centroids[1,0],centroids[1,1], '^', markersize=10, markeredgecolor='black',color='#5a696e') -plt.savefig('centroids_1.png') -""" - - -# We can also cluster the data with kmedians. The respective advanced initial centroid sampling is called 'kmedians++'. - -kmedians = ht.cluster.KMedians(n_clusters=2, init="kmedians++") -labels = kmedians.fit_predict(data).squeeze() -centroids = kmedians.cluster_centers_ - -# Select points assigned to clusters c1 and c2 -c1 = data[ht.where(labels == 0), :] -c2 = data[ht.where(labels == 1), :] -# After slicing, the arrays are not distributed equally among the processes anymore; we need to balance -c1.balance_() -c2.balance_() - -print( - f"Number of points assigned to c1: {c1.shape[0]} \n" - f"Number of points assigned to c2: {c2.shape[0]} \n" - f"Centroids = {centroids}" -) - - -# Plotting the assigned clusters and the respective centroids: - -c1_np = c1.numpy() -c2_np = c2.numpy() - -""" -plt.plot(c1_np[:,0], c1_np[:,1], 'x', color='#f0781e') -plt.plot(c2_np[:,0], c2_np[:,1], 'x', color='#5a696e') -plt.plot(centroids[0,0],centroids[0,1], '^', markersize=10, markeredgecolor='black', color='#f0781e' ) -plt.plot(centroids[1,0],centroids[1,1], '^', markersize=10, markeredgecolor='black',color='#5a696e') -plt.savefig('centroids_2.png') -""" - - -# The Iris Dataset -# ------------------------------ -# The _iris_ dataset is a well known example for clustering analysis. It contains 4 measured features for samples from -# three different types of iris flowers. A subset of 150 samples is included in formats h5, csv and netcdf in the [Heat repository under 'heat/heat/datasets'](https://github.com/helmholtz-analytics/heat/tree/main/heat/datasets), and can be loaded in a distributed manner with Heat's parallel dataloader. -# -# **NOTE: you might have to change the path to the dataset in the following cell.** - -iris = ht.load("iris.csv", sep=";", split=0) - - -# Feel free to try out the other [loading options](https://heat.readthedocs.io/en/stable/autoapi/heat/core/io/index.html#heat.core.io.load) as well. -# -# Fitting the dataset with `kmeans`: - -k = 3 -kmeans = ht.cluster.KMeans(n_clusters=k, init="kmeans++") -kmeans.fit(iris) - -# Let's see what the results are. In theory, there are 50 samples of each of the 3 iris types: setosa, versicolor and virginica. We will plot the results in a 3D scatter plot, coloring the samples according to the assigned cluster. - -labels = kmeans.predict(iris).squeeze() - -# Select points assigned to clusters c1, c2 and c3 -c1 = iris[ht.where(labels == 0), :] -c2 = iris[ht.where(labels == 1), :] -c3 = iris[ht.where(labels == 2), :] -# After slicing, the arrays are not distributed equally among the processes anymore; we need to balance -# TODO is balancing really necessary? -c1.balance_() -c2.balance_() -c3.balance_() - -print( - f"Number of points assigned to c1: {c1.shape[0]} \n" - f"Number of points assigned to c2: {c2.shape[0]} \n" - f"Number of points assigned to c3: {c3.shape[0]}" -) - - -# compare Heat results with sklearn -from sklearn.cluster import KMeans -import sklearn.datasets - -k = 3 -iris_sk = sklearn.datasets.load_iris().data -kmeans_sk = KMeans(n_clusters=k, init="k-means++").fit(iris_sk) -labels_sk = kmeans_sk.predict(iris_sk) - -c1_sk = iris_sk[labels_sk == 0, :] -c2_sk = iris_sk[labels_sk == 1, :] -c3_sk = iris_sk[labels_sk == 2, :] -print( - f"Number of points assigned to c1: {c1_sk.shape[0]} \n" - f"Number of points assigned to c2: {c2_sk.shape[0]} \n" - f"Number of points assigned to c3: {c3_sk.shape[0]}" -) diff --git a/tutorials/hpc/05_your_turn/now_its_your_turn.py b/tutorials/hpc/05_your_turn/now_its_your_turn.py new file mode 100644 index 0000000000..a20e67946e --- /dev/null +++ b/tutorials/hpc/05_your_turn/now_its_your_turn.py @@ -0,0 +1,48 @@ +import heat as ht +import numpy as np +import h5py + +# Now its your turn! Download one of the following three data sets and play around with it. +# Possible ideas: +# get familiar with the data: shape, min, max, avg, std (possibly along axes?) +# try SVD and/or QR to detect linear dependence +# K-Means Clustering (Asteroids, CERN?) +# Lasso (CERN?) +# n-dim FFT (CAMELS?)... + + +# "Asteroids": Asteroids of the Solar System +# Use +# ``` +# wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1E6mrMi7lL-UoHmZMTicVi0I2IdOKNn2-' -O ~/sbdb_asteroids.h5 +# ``` +# to download an example data set (92MB) consisting of the asteroids from the JPL Small Body Database (https://ssd.jpl.nasa.gov/tools/sbdb_lookup.html#/) + +X = ht.load_hdf5("~/sbdb_asteroids.h5", device="gpu", dataset="data", split=0) + +# ... to be completed ... + +# "CAMELS": 1000 simulated universes on 128 x 128 x 128 grids +# Take a bunch of 1000 simulated universes from the CAMELS data set (8GB): +# ``` +# wget https://users.flatironinstitute.org/~fvillaescusa/priv/DEPnzxoWlaTQ6CjrXqsm0vYi8L7Jy/CMD/3D_grids/data/Nbody/Grids_Mtot_Nbody_Astrid_LH_128_z=0.0.npy ~/Grids_Mtot_Nbody_Astrid_LH_128_z=0.0.npy +# ``` +# load them in NumPy, convert to PyTorch and Heat... + +X_np = np.load("~/Grids_Mtot_Nbody_Astrid_LH_128_z=0.0.npy") + +# ... to be completed ... + +# "CERN": A particle physics data set from CERN +# Take a small part of the ATLAS Top Tagging Data Set from CERN (7.6GB, actually the "test"-part; the "train" part is much larger...) +# ``` +# wget https://opendata.cern.ch/record/15013/files/test.h5 ~/test.h5 +# ``` +# and load it directly into Heat (watch out: the h5-file contains different data sets that need to be stacked...) + +filename = "~/test.h5" +with h5py.File(filename, "r") as f: + features = f.keys() + arrays = [ht.load_hdf5(filename, feature, split=0) for feature in features] + +# ... to be completed ... From 5120b9f3bad8e895969e3c3808585a289ea4122e Mon Sep 17 00:00:00 2001 From: Hoppe Date: Mon, 10 Jun 2024 17:24:52 +0200 Subject: [PATCH 5/7] added batch script --- tutorials/hpc/02_loading_preprocessing/01_IO.py | 2 +- tutorials/hpc/run_script.sh | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 tutorials/hpc/run_script.sh diff --git a/tutorials/hpc/02_loading_preprocessing/01_IO.py b/tutorials/hpc/02_loading_preprocessing/01_IO.py index 0c9403767b..ea8aec1545 100644 --- a/tutorials/hpc/02_loading_preprocessing/01_IO.py +++ b/tutorials/hpc/02_loading_preprocessing/01_IO.py @@ -22,7 +22,7 @@ import heat as ht # Some random data for small scale tests -iris = ht.load("iris.csv", sep=";", split=0) +iris = ht.load("~/heat/tutorials/02_loading_preprocessing/iris.csv", sep=";", split=0) print(iris) # We have loaded the entire data onto 4 MPI processes, each with 12 cores. We have created `X` with `split=0`, so each process stores evenly-sized slices of the data along dimension 0. diff --git a/tutorials/hpc/run_script.sh b/tutorials/hpc/run_script.sh new file mode 100644 index 0000000000..3386fdab32 --- /dev/null +++ b/tutorials/hpc/run_script.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --partition=normal +#SBATCH --reservation=haicon +#SBATCH --nodes=1 +#SBATCH --tasks-per-node=4 +#SBATCH --cpus-per-task=12 +#SBATCH --gres=gpu:4 +#SBATCH --time="00:01:00" + +export MKL_NUM_THREADS=$SLURM_CPUS_PER_TASK +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK + +srun python ~/heat/tutorials/hpc/01_basics/01_basics_dndarrays.py From cfeabecb3e03e322b869fa11cd65f0164ce8456b Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:40:20 +0200 Subject: [PATCH 6/7] Update tutorials/hpc/01_basics/05_basics_broadcast.py --- tutorials/hpc/01_basics/05_basics_broadcast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/hpc/01_basics/05_basics_broadcast.py b/tutorials/hpc/01_basics/05_basics_broadcast.py index 0eabb4fed5..20cd843ec8 100644 --- a/tutorials/hpc/01_basics/05_basics_broadcast.py +++ b/tutorials/hpc/01_basics/05_basics_broadcast.py @@ -15,5 +15,5 @@ ) b = ht.arange(4) print( - f"broadcasing across the first dimension of {a} with shape = (3, 4) and {b} with shape = (4): {a+b}" + f"broadcasting across the first dimension of {a} with shape = (3, 4) and {b} with shape = (4): {a + b}" ) From ddc1c7d2d527ca13a3f1cea33d80e5ed2403f63e Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:40:39 +0200 Subject: [PATCH 7/7] Update tutorials/hpc/02_loading_preprocessing/02_preprocessing.py --- tutorials/hpc/02_loading_preprocessing/02_preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/hpc/02_loading_preprocessing/02_preprocessing.py b/tutorials/hpc/02_loading_preprocessing/02_preprocessing.py index 33c8729765..d3195ab5c1 100644 --- a/tutorials/hpc/02_loading_preprocessing/02_preprocessing.py +++ b/tutorials/hpc/02_loading_preprocessing/02_preprocessing.py @@ -12,7 +12,7 @@ # print global metadata once only if X.comm.rank == 0: print(f"X is a {X.ndim}-dimensional array with shape{X.shape}") - print(f"X takes up {X.nbytes/1e6} MB of memory.") + print(f"X takes up {X.nbytes / 1e6} MB of memory.") # X is a matrix of shape *(datapoints, features)*. #