diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ba46e60b4..7c337761c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -412,6 +412,7 @@ add_library( src/neighbors/nn_descent_float.cu src/neighbors/nn_descent_int8.cu src/neighbors/nn_descent_uint8.cu + src/neighbors/reachability.cu src/neighbors/refine/detail/refine_device_float_float.cu src/neighbors/refine/detail/refine_device_half_float.cu src/neighbors/refine/detail/refine_device_int8_t_float.cu diff --git a/cpp/include/cuvs/neighbors/reachability.hpp b/cpp/include/cuvs/neighbors/reachability.hpp new file mode 100644 index 000000000..c7ce9474d --- /dev/null +++ b/cpp/include/cuvs/neighbors/reachability.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace cuvs::neighbors::reachability { + +/** + * @defgroup reachability_cpp Mutual Reachability + * @{ + */ +/** + * Constructs a mutual reachability graph, which is a k-nearest neighbors + * graph projected into mutual reachability space using the following + * function for each data point, where core_distance is the distance + * to the kth neighbor: max(core_distance(a), core_distance(b), d(a, b)) + * + * Unfortunately, points in the tails of the pdf (e.g. in sparse regions + * of the space) can have very large neighborhoods, which will impact + * nearby neighborhoods. 
Because of this, it's possible for the + * radius of points in the main mass, which might have a very small + * radius initially, to expand considerably. As a result, the initial + * knn which was used to compute the core distances may no longer + * capture the actual neighborhoods after projection into mutual + * reachability space. + * + * For the experimental version, we execute the knn twice: once + * to compute the radii (core distances) and again to capture + * the final neighborhoods. Future iterations of this algorithm + * will improve upon this "exact" version by using + * more specialized data structures, such as space-partitioning + * structures. It has also been shown that approximate nearest + * neighbors can yield reasonable neighborhoods as the + * data sizes increase. + * + * @param[in] handle raft handle for resource reuse + * @param[in] X input data points (size m * n) + * @param[in] min_samples neighborhood size; the distance to the min_samples-th nearest + * neighbor is used as each point's core distance + * @param[out] indptr CSR indptr of output knn graph (size m + 1) + * @param[out] core_dists output core distances array (size m) + * @param[out] out COO object, uninitialized on entry, on exit it stores the + * (symmetrized) mutual reachability distance for the k nearest + * neighbors. 
+ * @param[in] metric distance metric to use, default Euclidean + * @param[in] alpha weight applied when internal distance is chosen for + * mutual reachability (value of 1.0 disables the weighting) + */ +void mutual_reachability_graph( + const raft::resources& handle, + raft::device_matrix_view X, + int min_samples, + raft::device_vector_view indptr, + raft::device_vector_view core_dists, + raft::sparse::COO& out, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded, + float alpha = 1.0); +/** + * @} + */ +} // namespace cuvs::neighbors::reachability diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index e3f7acc96..dea1782c6 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -62,7 +62,10 @@ namespace cuvs::neighbors::detail { * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -78,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t max_col_tile_size = 0, const DistanceT* precomputed_index_norms = nullptr, const DistanceT* precomputed_search_norms = nullptr, - const uint32_t* filter_bitmap = nullptr) + const uint32_t* filter_bitmap = nullptr, + DistanceEpilogue distance_epilogue = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -209,7 +213,8 @@ void tiled_brute_force_knn(const raft::resources& handle, IndexType col = j + (idx % current_centroid_size); cuvs::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); - return l2_op(row_norms[row], col_norms[col], dist[idx]); + auto val = l2_op(row_norms[row], col_norms[col], dist[idx]); + return distance_epilogue(val, row, col); }); } else if (metric == 
cuvs::distance::DistanceType::CosineExpanded) { auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); @@ -223,8 +228,22 @@ void tiled_brute_force_knn(const raft::resources& handle, IndexType row = i + (idx / current_centroid_size); IndexType col = j + (idx % current_centroid_size); auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]); - return val; + return distance_epilogue(val, row, col); }); + } else { + // if we're not l2 distance, and we have a distance epilogue - run it now + if constexpr (!std::is_same_v) { + auto distances_ptr = temp_distances.data(); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(temp_distances.data(), + current_query_size * current_centroid_size), + [=] __device__(size_t idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + return distance_epilogue(distances_ptr[idx], row, col); + }); + } } if (filter_bitmap != nullptr) { diff --git a/cpp/src/neighbors/detail/reachability.cuh b/cpp/src/neighbors/detail/reachability.cuh new file mode 100644 index 000000000..4932c36ab --- /dev/null +++ b/cpp/src/neighbors/detail/reachability.cuh @@ -0,0 +1,281 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "./knn_brute_force.cuh" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cuvs::neighbors::detail::reachability { + +/** + * Extract core distances from KNN graph. This is essentially + * performing a knn_dists[:,min_pts] + * @tparam value_idx data type for integrals + * @tparam value_t data type for distance + * @tparam tpb block size for kernel + * @param[in] knn_dists knn distance array (size n * k) + * @param[in] min_samples this neighbor will be selected for core distances + * @param[in] n_neighbors the number of neighbors of each point in the knn graph + * @param[in] n number of samples + * @param[out] out output array (size n) + * @param[in] stream stream for which to order cuda operations + */ +template +void core_distances( + value_t* knn_dists, int min_samples, int n_neighbors, size_t n, value_t* out, cudaStream_t stream) +{ + ASSERT(n_neighbors >= min_samples, + "the size of the neighborhood should be greater than or equal to min_samples"); + + auto exec_policy = rmm::exec_policy(stream); + + auto indices = thrust::make_counting_iterator(0); + + thrust::transform(exec_policy, indices, indices + n, out, [=] __device__(value_idx row) { + return knn_dists[row * n_neighbors + (min_samples - 1)]; + }); +} + +/** + * Wraps the brute force knn API, to be used for both training and prediction + * @tparam value_idx data type for integrals + * @tparam value_t data type for distance + * @param[in] handle raft handle for resource reuse + * @param[in] X input data points (size m * n) + * @param[out] inds nearest neighbor indices (size n_search_items * k) + * @param[out] dists nearest neighbor distances (size n_search_items * k) + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] search_items array of items to search of dimensionality D (size n_search_items * n) + * @param[in] n_search_items number of rows in 
search_items + * @param[in] k number of nearest neighbors + * @param[in] metric distance metric to use + */ +template +void compute_knn(const raft::resources& handle, + const value_t* X, + value_idx* inds, + value_t* dists, + size_t m, + size_t n, + const value_t* search_items, + size_t n_search_items, + int k, + cuvs::distance::DistanceType metric) +{ + auto stream = raft::resource::get_cuda_stream(handle); + auto exec_policy = raft::resource::get_thrust_policy(handle); + std::vector inputs; + inputs.push_back(const_cast(X)); + + std::vector sizes; + sizes.push_back(m); + + // the tiled_brute_force_knn code only works with int64 indices, convert + rmm::device_uvector int64_indices(k * n_search_items, stream); + + // perform knn + tiled_brute_force_knn( + handle, X, search_items, m, n_search_items, n, k, dists, int64_indices.data(), metric); + + // convert from current knn's 64-bit to 32-bit. + thrust::transform(exec_policy, + int64_indices.data(), + int64_indices.data() + int64_indices.size(), + inds, + [] __device__(int64_t in) -> value_idx { return in; }); +} + +/* + @brief Internal function for CPU->GPU interop + to compute core_dists +*/ +template +void _compute_core_dists(const raft::resources& handle, + const value_t* X, + value_t* core_dists, + size_t m, + size_t n, + cuvs::distance::DistanceType metric, + int min_samples) +{ + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, + "Currently only L2 expanded distance is supported"); + + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector inds(min_samples * m, stream); + rmm::device_uvector dists(min_samples * m, stream); + + // perform knn + compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); + + // Slice core distances (distances to kth nearest neighbor) + core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); +} + +// Functor to post-process distances into reachability space +template +struct 
ReachabilityPostProcess { + DI value_t operator()(value_t value, value_idx row, value_idx col) const + { + return max(core_dists[col], max(core_dists[row], alpha * value)); + } + + const value_t* core_dists; + value_t alpha; +}; + +/** + * Given core distances, Fuses computations of L2 distances between all + * points, projection into mutual reachability space, and k-selection. + * @tparam value_idx + * @tparam value_t + * @param[in] handle raft handle for resource reuse + * @param[out] out_inds output indices array (size m * k) + * @param[out] out_dists output distances array (size m * k) + * @param[in] X input data points (size m * n) + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] k neighborhood size (includes self-loop) + * @param[in] core_dists array of core distances (size m) + */ +template +void mutual_reachability_knn_l2(const raft::resources& handle, + value_idx* out_inds, + value_t* out_dists, + const value_t* X, + size_t m, + size_t n, + int k, + value_t* core_dists, + value_t alpha) +{ + // TODO: we are dealing with int32 indices (for compatibility with cuml) + // but the tiled_brute_force_knn code only works with int64. 
convert + rmm::device_uvector int64_indices(k * m, raft::resource::get_cuda_stream(handle)); + + // Create a functor to postprocess distances into mutual reachability space + // Note that we can't use a lambda for this here, since we get errors like: + // `A type local to a function cannot be used in the template argument of the + // enclosing parent function (and any parent classes) of an extended __device__ + // or __host__ __device__ lambda` + auto epilogue = ReachabilityPostProcess{core_dists, alpha}; + + cuvs::neighbors::detail:: + tiled_brute_force_knn>( + handle, + X, + X, + m, + m, + n, + k, + out_dists, + int64_indices.data(), + cuvs::distance::DistanceType::L2SqrtExpanded, + 2.0, + 0, + 0, + nullptr, + nullptr, + nullptr, + epilogue); + + // convert from current knn's 64-bit to 32-bit. + thrust::transform(raft::resource::get_thrust_policy(handle), + int64_indices.data(), + int64_indices.data() + int64_indices.size(), + out_inds, + [] __device__(int64_t in) -> value_idx { return in; }); +} + +template +void mutual_reachability_graph(const raft::resources& handle, + const value_t* X, + size_t m, + size_t n, + cuvs::distance::DistanceType metric, + int min_samples, + value_t alpha, + value_idx* indptr, + value_t* core_dists, + raft::sparse::COO& out) +{ + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, + "Currently only L2 expanded distance is supported"); + + auto stream = raft::resource::get_cuda_stream(handle); + auto exec_policy = raft::resource::get_thrust_policy(handle); + + rmm::device_uvector coo_rows(min_samples * m, stream); + rmm::device_uvector inds(min_samples * m, stream); + rmm::device_uvector dists(min_samples * m, stream); + + // perform knn + compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); + + // Slice core distances (distances to kth nearest neighbor) + core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); + + /** + * Compute L2 norm + */ + 
mutual_reachability_knn_l2( + handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha); + + // self-loops get max distance + auto coo_rows_counting_itr = thrust::make_counting_iterator(0); + thrust::transform(exec_policy, + coo_rows_counting_itr, + coo_rows_counting_itr + (m * min_samples), + coo_rows.data(), + [min_samples] __device__(value_idx c) -> value_idx { return c / min_samples; }); + + raft::sparse::linalg::symmetrize( + handle, coo_rows.data(), inds.data(), dists.data(), m, m, min_samples * m, out); + + raft::sparse::convert::sorted_coo_to_csr(out.rows(), out.nnz, indptr, m + 1, stream); + + // self-loops get max distance + auto transform_in = + thrust::make_zip_iterator(thrust::make_tuple(out.rows(), out.cols(), out.vals())); + + thrust::transform(exec_policy, + transform_in, + transform_in + out.nnz, + out.vals(), + [=] __device__(const thrust::tuple& tup) { + return thrust::get<0>(tup) == thrust::get<1>(tup) + ? std::numeric_limits::max() + : thrust::get<2>(tup); + }); +} + +} // namespace cuvs::neighbors::detail::reachability diff --git a/cpp/src/neighbors/reachability.cu b/cpp/src/neighbors/reachability.cu new file mode 100644 index 000000000..012864645 --- /dev/null +++ b/cpp/src/neighbors/reachability.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "./detail/reachability.cuh" + +namespace cuvs::neighbors::reachability { + +void mutual_reachability_graph(const raft::resources& handle, + raft::device_matrix_view X, + int min_samples, + raft::device_vector_view indptr, + raft::device_vector_view core_dists, + raft::sparse::COO& out, + cuvs::distance::DistanceType metric, + float alpha) +{ + // TODO: assert core_dists/indptr have right shape + // TODO: add test + cuvs::neighbors::detail::reachability::mutual_reachability_graph( + handle, + X.data_handle(), + X.extent(0), + X.extent(1), + metric, + min_samples, + alpha, + indptr.data_handle(), + core_dists.data_handle(), + out); +} +} // namespace cuvs::neighbors::reachability