diff --git a/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh b/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh index 1dbc843d0..86de55db6 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh @@ -25,564 +25,26 @@ #pragma once -#include "compute_distance_standard.cuh" -#include "compute_distance_vpq.cuh" +#include "compute_distance_standard.hpp" +#include "compute_distance_vpq.hpp" namespace cuvs::neighbors::cagra::detail { -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct standard_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct cagra_q_dataset_descriptor_t; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct standard_descriptor_spec; -extern template struct vpq_descriptor_spec; -extern template struct vpq_descriptor_spec; extern template struct instance_selector< - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec>; + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec>; using descriptor_instances = instance_selector< - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec>; + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec>; template auto dataset_descriptor_init(const cagra::search_params& params, diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.cu b/cpp/src/neighbors/detail/cagra/compute_distance.cu index 5d480f57a..387b4c71b 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance.cu @@ -27,536 +27,80 @@ namespace cuvs::neighbors::cagra::detail { +using namespace cuvs::distance; + template struct instance_selector< - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec, - standard_descriptor_spec, - standard_descriptor_spec, - vpq_descriptor_spec, - vpq_descriptor_spec>; + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + vpq_descriptor_spec, + vpq_descriptor_spec>; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py index 1f2b24e10..1b0743901 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py @@ -43,6 +43,7 @@ namespace cuvs::neighbors::cagra::detail {{ +using namespace cuvs::distance; {content} }} // namespace cuvs::neighbors::cagra::detail @@ -69,7 +70,7 @@ half_uint64=("half", "uint64_t", "float"), ) -metric_prefix = 'cuvs::distance::DistanceType::' +metric_prefix = 'DistanceType::' specs = [] descs = [] @@ -90,17 +91,10 @@ # CAGRA for metric in ['L2Expanded', 'InnerProduct']: path = f"compute_distance_standard_{metric}_{type_path}_dim{mxdim}_t{team}.cu" - includes = '#include "compute_distance_standard.cuh"' + includes = '#include "compute_distance_standard-impl.cuh"' params = f"{metric_prefix}{metric}, {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}" spec = f"standard_descriptor_spec<{params}>" - desc = f"standard_dataset_descriptor_t<{params}>" - content = f""" -template struct {desc}; -template <> -const void* {spec}::init_kernel = reinterpret_cast(&standard_dataset_descriptor_init_kernel<{params}>); -template struct {spec}; -""" - descs.append(desc) + content = f"""template struct {spec};""" specs.append(spec) with open(path, "w") as f: f.write(template.format(includes=includes, content=content)) @@ -112,17 +106,10 @@ for pq_bit in pq_bits: for metric in ['L2Expanded']: path = f"compute_distance_vpq_{metric}_{type_path}_dim{mxdim}_t{team}_{pq_bit}pq_{pq_len}subd_{code_book_t}.cu" - includes = '#include "compute_distance_vpq.cuh"' + includes = '#include "compute_distance_vpq-impl.cuh"' params = f"{metric_prefix}{metric}, {team}, {mxdim}, {pq_bit}, {pq_len}, {code_book_t}, {data_t}, {idx_t}, {distance_t}" spec = f"vpq_descriptor_spec<{params}>" - desc = f"cagra_q_dataset_descriptor_t<{params}>" - content = f""" -template struct {desc}; -template <> -const void* {spec}::init_kernel = reinterpret_cast(&vpq_dataset_descriptor_init_kernel<{params}>); -template struct {spec}; -""" - descs.append(desc) + content = f"""template struct {spec};""" specs.append(spec) with open(path, "w") as f: f.write(template.format(includes=includes, content=content)) @@ -132,12 +119,11 @@ includes = ''' #pragma once -#include "compute_distance_standard.cuh" -#include "compute_distance_vpq.cuh" +#include "compute_distance_standard.hpp" +#include "compute_distance_vpq.hpp" ''' newline = "\n" contents = f''' -{newline.join(map(lambda s: "extern template struct " + s + ";", descs))} {newline.join(map(lambda s: "extern template struct " + s + ";", specs))} extern template struct diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh similarity index 83% rename from cpp/src/neighbors/detail/cagra/compute_distance_standard.cuh rename to cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh index c1d38ead1..7fe5242a9 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include "compute_distance.hpp" +#include "compute_distance_standard.hpp" #include #include @@ -262,57 +262,30 @@ template -struct standard_descriptor_spec : public instance_spec { - using base_type = instance_spec; - using typename base_type::data_type; - using typename base_type::distance_type; - using typename base_type::host_type; - using typename base_type::index_type; - - template - constexpr static inline bool accepts_dataset() - { - return is_strided_dataset_v; - } - - using descriptor_type = +dataset_descriptor_host +standard_descriptor_spec::init_( + const cagra::search_params& params, + const DataT* ptr, + IndexT size, + uint32_t dim, + uint32_t ld, + rmm::cuda_stream_view stream) +{ + using desc_type = standard_dataset_descriptor_t; - static const void* init_kernel; - - template - static auto init(const cagra::search_params& params, - const DatasetT& dataset, - cuvs::distance::DistanceType metric, - rmm::cuda_stream_view stream) -> host_type - { - descriptor_type dd_host{nullptr, - nullptr, - dataset.view().data_handle(), - IndexT(dataset.n_rows()), - dataset.dim(), - dataset.stride()}; - host_type result{dd_host, stream, DatasetBlockDim}; - void* args[] = // NOLINT - {&result.dev_ptr, - &descriptor_type::ptr(dd_host.args), - &dd_host.size, - &dd_host.args.dim, - &descriptor_type::ld(dd_host.args)}; - RAFT_CUDA_TRY(cudaLaunchKernel(init_kernel, 1, 1, args, 0, stream)); - return result; - } + using base_type = typename desc_type::base_type; + desc_type dd_host{nullptr, nullptr, ptr, size, dim, ld}; + host_type result{dd_host, stream, DatasetBlockDim}; - template - static auto priority(const cagra::search_params& params, - const DatasetT& dataset, - cuvs::distance::DistanceType metric) -> double - { - // If explicit team_size is specified and doesn't match the instance, discard it - if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } - if (Metric != metric) { return -1.0; } - // Otherwise, favor the closest dataset dimensionality. - return 1.0 / (0.1 + std::abs(double(dataset.dim()) - double(DatasetBlockDim))); - } -}; + standard_dataset_descriptor_init_kernel + <<<1, 1, 0, stream>>>(result.dev_ptr, ptr, size, dim, desc_type::ld(dd_host.args)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + return result; +} } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp new file mode 100644 index 000000000..df1b77e86 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "compute_distance.hpp" + +#include + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +struct standard_descriptor_spec : public instance_spec { + using base_type = instance_spec; + using typename base_type::data_type; + using typename base_type::distance_type; + using typename base_type::host_type; + using typename base_type::index_type; + + template + constexpr static inline bool accepts_dataset() + { + return is_strided_dataset_v; + } + + template + static auto init(const cagra::search_params& params, + const DatasetT& dataset, + cuvs::distance::DistanceType metric, + rmm::cuda_stream_view stream) -> host_type + { + return init_(params, + dataset.view().data_handle(), + IndexT(dataset.n_rows()), + dataset.dim(), + dataset.stride(), + stream); + } + + template + static auto priority(const cagra::search_params& params, + const DatasetT& dataset, + cuvs::distance::DistanceType metric) -> double + { + // If explicit team_size is specified and doesn't match the instance, discard it + if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } + if (Metric != metric) { return -1.0; } + // Otherwise, favor the closest dataset dimensionality. + return 1.0 / (0.1 + std::abs(double(dataset.dim()) - double(DatasetBlockDim))); + } + + private: + static dataset_descriptor_host init_(const cagra::search_params& params, + const DataT* ptr, + IndexT size, + uint32_t dim, + uint32_t ld, + rmm::cuda_stream_view stream); +}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu index bc1900856..af5e89a76 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu @@ -23,31 +23,12 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim256_t16.cu index f1ace30cc..cfad79f3a 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim256_t16.cu @@ -23,31 +23,12 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint64_dim256_t16.cu index 4528426c7..32a18ff3e 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint64_dim256_t16.cu @@ -23,31 +23,12 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim256_t16.cu index cdb315bac..7d1206c37 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim256_t16.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim512_t32.cu index 49053a2d6..251316b2c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint32_dim512_t32.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim128_t8.cu index 5a534718b..7a8c4059c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim128_t8.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim256_t16.cu index 7e85fa349..fcc65a48e 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim256_t16.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim512_t32.cu index 4bc254679..833dac9c4 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_float_uint64_dim512_t32.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim128_t8.cu index c0fe52caf..e3870df40 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim128_t8.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim256_t16.cu index b585e1f80..1253d7cd4 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim256_t16.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim512_t32.cu index 91de967e8..792532c2c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim512_t32.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim128_t8.cu index b77b84793..b3a466f46 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim128_t8.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim256_t16.cu index 7ce86c034..a11701e5a 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim256_t16.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim512_t32.cu index 507d709eb..9ed0a32ee 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint64_dim512_t32.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim128_t8.cu index c5c7a7b4c..c9c960cf9 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim128_t8.cu @@ -23,35 +23,11 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; +using namespace cuvs::distance; +template struct standard_descriptor_spec; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim256_t16.cu index 8d237f58b..d7a12804b 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim256_t16.cu @@ -23,31 +23,12 @@ * */ -#include "compute_distance_standard.cuh" +#include "compute_distance_standard-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct standard_dataset_descriptor_t; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec; -template <> -const void* standard_descriptor_spec::init_kernel = - reinterpret_cast( - &standard_dataset_descriptor_init_kernel); -template struct standard_descriptor_spec #include @@ -30,13 +30,13 @@ template struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; - using CODE_BOOK_T = CodeBookT; + using CODE_BOOK_T = CodebookT; using QUERY_T = half; using base_type::args; using base_type::extra_ptr3; @@ -119,7 +119,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t __launch_bounds__(1, 1) __global__ void vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t* out, const std::uint8_t* encoded_dataset_ptr, - std::uint32_t encoded_dataset_dim, - std::uint32_t n_subspace, - const CodeBookT* vq_code_book_ptr, - const CodeBookT* pq_code_book_ptr, - std::size_t size, - std::uint32_t dim) + uint32_t encoded_dataset_dim, + uint32_t n_subspace, + const CodebookT* vq_code_book_ptr, + const CodebookT* pq_code_book_ptr, + IndexT size, + uint32_t dim) { using desc_type = cagra_q_dataset_descriptor_t; @@ -404,81 +404,64 @@ template -struct vpq_descriptor_spec : public instance_spec { - using base_type = instance_spec; - using typename base_type::data_type; - using typename base_type::distance_type; - using typename base_type::host_type; - using typename base_type::index_type; - - template - constexpr static inline auto accepts_dataset() - -> std::enable_if_t, bool> - { - return std::is_same_v; - } - - template - constexpr static inline auto accepts_dataset() - -> std::enable_if_t, bool> - { - return false; - } - - using descriptor_type = cagra_q_dataset_descriptor_t; - static const void* init_kernel; - - template - static auto init(const cagra::search_params& params, - const DatasetT& dataset, - cuvs::distance::DistanceType metric, - rmm::cuda_stream_view stream) -> host_type - { - descriptor_type dd_host{nullptr, - nullptr, - dataset.data.data_handle(), - dataset.encoded_row_length(), - dataset.pq_dim(), - dataset.vq_code_book.data_handle(), - dataset.pq_code_book.data_handle(), - IndexT(dataset.n_rows()), - dataset.dim()}; - host_type result{dd_host, stream, DatasetBlockDim}; - void* args[] = // NOLINT - {&result.dev_ptr, - &descriptor_type::encoded_dataset_ptr(dd_host.args), - &descriptor_type::encoded_dataset_dim(dd_host.args), - &descriptor_type::n_subspace(dd_host.args), - &descriptor_type::vq_code_book_ptr(dd_host.args), - &dd_host.pq_code_book_ptr(), - &dd_host.size, - &dd_host.args.dim}; - RAFT_CUDA_TRY(cudaLaunchKernel(init_kernel, 1, 1, args, 0, stream)); - return result; - } +dataset_descriptor_host +vpq_descriptor_spec::init_(const cagra::search_params& params, + const std::uint8_t* encoded_dataset_ptr, + uint32_t encoded_dataset_dim, + uint32_t n_subspace, + const CodebookT* vq_code_book_ptr, + const CodebookT* pq_code_book_ptr, + IndexT size, + uint32_t dim, + rmm::cuda_stream_view stream) +{ + using desc_type = cagra_q_dataset_descriptor_t; + using base_type = typename desc_type::base_type; - template - static auto priority(const cagra::search_params& params, - const DatasetT& dataset, - cuvs::distance::DistanceType metric) -> double - { - // If explicit team_size is specified and doesn't match the instance, discard it - if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } - if (cuvs::distance::DistanceType::L2Expanded != metric) { return -1.0; } - // Match codebook params - if (dataset.pq_bits() != PqBits) { return -1.0; } - if (dataset.pq_len() != PqLen) { return -1.0; } - // Otherwise, favor the closest dataset dimensionality. - return 1.0 / (0.1 + std::abs(double(dataset.dim()) - double(DatasetBlockDim))); - } -}; + desc_type dd_host{nullptr, + nullptr, + encoded_dataset_ptr, + encoded_dataset_dim, + n_subspace, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim}; + host_type result{dd_host, stream, DatasetBlockDim}; + vpq_dataset_descriptor_init_kernel<<<1, 1, 0, stream>>>(result.dev_ptr, + encoded_dataset_ptr, + encoded_dataset_dim, + n_subspace, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + return result; +} } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp new file mode 100644 index 000000000..9d5b0b6c0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "compute_distance.hpp" + +#include + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +struct vpq_descriptor_spec : public instance_spec { + using base_type = instance_spec; + using typename base_type::data_type; + using typename base_type::distance_type; + using typename base_type::host_type; + using typename base_type::index_type; + + template + constexpr static inline auto accepts_dataset() + -> std::enable_if_t, bool> + { + return std::is_same_v; + } + + template + constexpr static inline auto accepts_dataset() + -> std::enable_if_t, bool> + { + return false; + } + + template + static auto init(const cagra::search_params& params, + const DatasetT& dataset, + cuvs::distance::DistanceType metric, + rmm::cuda_stream_view stream) -> host_type + { + return init_(params, + dataset.data.data_handle(), + dataset.encoded_row_length(), + dataset.pq_dim(), + dataset.vq_code_book.data_handle(), + dataset.pq_code_book.data_handle(), + IndexT(dataset.n_rows()), + dataset.dim(), + stream); + } + + template + static auto priority(const cagra::search_params& params, + const DatasetT& dataset, + cuvs::distance::DistanceType metric) -> double + { + // If explicit team_size is specified and doesn't match the instance, discard it + if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } + if (cuvs::distance::DistanceType::L2Expanded != metric) { return -1.0; } + // Match codebook params + if (dataset.pq_bits() != PqBits) { return -1.0; } + if (dataset.pq_len() != PqLen) { return -1.0; } + // Otherwise, favor the closest dataset dimensionality. + return 1.0 / (0.1 + std::abs(double(dataset.dim()) - double(DatasetBlockDim))); + } + + private: + static dataset_descriptor_host init_( + const cagra::search_params& params, + const std::uint8_t* encoded_dataset_ptr, + uint32_t encoded_dataset_dim, + uint32_t n_subspace, + const CodebookT* vq_code_book_ptr, + const CodebookT* pq_code_book_ptr, + IndexT size, + uint32_t dim, + rmm::cuda_stream_view stream); +}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim128_t8_8pq_2subd_half.cu index 7abc27bda..a56a5a9df 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim128_t8_8pq_2subd_half.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim128_t8_8pq_2subd_half.cu @@ -23,40 +23,12 @@ * */ -#include "compute_distance_vpq.cuh" +#include "compute_distance_vpq-impl.cuh" namespace cuvs::neighbors::cagra::detail { -template struct cagra_q_dataset_descriptor_t; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec; -template <> -const void* vpq_descriptor_spec::init_kernel = - reinterpret_cast( - &vpq_dataset_descriptor_init_kernel); -template struct vpq_descriptor_spec