Skip to content

Commit

Permalink
Add function for calculating the mutual_reachability_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Sep 11, 2024
1 parent c40abae commit cc7d253
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 4 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ add_library(
src/neighbors/nn_descent_float.cu
src/neighbors/nn_descent_int8.cu
src/neighbors/nn_descent_uint8.cu
src/neighbors/reachability.cu
src/neighbors/refine/detail/refine_device_float_float.cu
src/neighbors/refine/detail/refine_device_half_float.cu
src/neighbors/refine/detail/refine_device_int8_t_float.cu
Expand Down
79 changes: 79 additions & 0 deletions cpp/include/cuvs/neighbors/reachability.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/sparse/coo.hpp>

#include <cuvs/distance/distance.hpp>

namespace cuvs::neighbors::reachability {

/**
* @defgroup reachability_cpp Mutual Reachability
* @{
*/
/**
* Constructs a mutual reachability graph, which is a k-nearest neighbors
* graph projected into mutual reachability space using the following
* function for each data point, where core_distance is the distance
* to the kth neighbor: max(core_distance(a), core_distance(b), d(a, b))
*
* Unfortunately, points in the tails of the pdf (e.g. in sparse regions
* of the space) can have very large neighborhoods, which will impact
* nearby neighborhoods. Because of this, it's possible that the
* radius for points in the main mass, which might have a very small
* radius initially, to expand very large. As a result, the initial
* knn which was used to compute the core distances may no longer
* capture the actual neighborhoods after projection into mutual
* reachability space.
*
* For the experimental version, we execute the knn twice- once
* to compute the radii (core distances) and again to capture
* the final neighborhoods. Future iterations of this algorithm
* will work improve upon this "exact" version, by using
* more specialized data structures, such as space-partitioning
* structures. It has also been shown that approximate nearest
* neighbors can yield reasonable neighborhoods as the
* data sizes increase.
*
* @param[in] handle raft handle for resource reuse
* @param[in] X input data points (size m * n)
* @param[in] min_samples this neighborhood will be selected for core distances
* @param[out] indptr CSR indptr of output knn graph (size m + 1)
* @param[out] core_dists output core distances array (size m)
* @param[out] out COO object, uninitialized on entry, on exit it stores the
* (symmetrized) maximum reachability distance for the k nearest
* neighbors.
* @param[in] metric distance metric to use, default Euclidean
* @param[in] alpha weight applied when internal distance is chosen for
* mutual reachability (value of 1.0 disables the weighting)
*/
void mutual_reachability_graph(
const raft::resources& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> X,
int min_samples,
raft::device_vector_view<int> indptr,
raft::device_vector_view<float> core_dists,
raft::sparse::COO<float, int>& out,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded,
float alpha = 1.0);
/**
* @}
*/
} // namespace cuvs::neighbors::reachability
27 changes: 23 additions & 4 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ namespace cuvs::neighbors::detail {
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
*/
template <typename ElementType = float, typename IndexType = int64_t, typename DistanceT = float>
template <typename ElementType = float,
typename IndexType = int64_t,
typename DistanceT = float,
typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::resources& handle,
const ElementType* search, // size (m ,d)
const ElementType* index, // size (n ,d)
Expand All @@ -78,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t max_col_tile_size = 0,
const DistanceT* precomputed_index_norms = nullptr,
const DistanceT* precomputed_search_norms = nullptr,
const uint32_t* filter_bitmap = nullptr)
const uint32_t* filter_bitmap = nullptr,
DistanceEpilogue distance_epilogue = raft::identity_op())
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -209,7 +213,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType col = j + (idx % current_centroid_size);

cuvs::distance::detail::ops::l2_exp_cutlass_op<DistanceT, DistanceT> l2_op(sqrt);
return l2_op(row_norms[row], col_norms[col], dist[idx]);
auto val = l2_op(row_norms[row], col_norms[col], dist[idx]);
return distance_epilogue(val, row, col);
});
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data();
Expand All @@ -223,8 +228,22 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]);
return val;
return distance_epilogue(val, row, col);
});
} else {
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
auto distances_ptr = temp_distances.data();
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(temp_distances.data(),
current_query_size * current_centroid_size),
[=] __device__(size_t idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
return distance_epilogue(distances_ptr[idx], row, col);
});
}
}

if (filter_bitmap != nullptr) {
Expand Down
Loading

0 comments on commit cc7d253

Please sign in to comment.