Skip to content

Commit

Permalink
Introduce internal_idx_t in ivf_pq::build()
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Mar 18, 2024
1 parent 78f4164 commit 2bf77db
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ namespace raft::neighbors::ivf_pq::detail {

using namespace raft::spatial::knn::detail; // NOLINT

// Index (extent) type used for the mdspan views created inside this file, i.e. the second
// template argument passed to raft::make_device_matrix_view / raft::make_device_vector_view.
// Named once here so those call sites do not each hard-code int64_t.
using internal_idx_t = int64_t;  // The default mdspan extent type used internally.

template <uint32_t BlockDim, typename T, typename S>
__launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel(
T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows)
Expand Down Expand Up @@ -442,15 +444,15 @@ void train_per_subset(raft::resources const& handle,
stream);

// train PQ codebook for this subspace
auto sub_trainset_view = raft::make_device_matrix_view<const float, int64_t>(
auto sub_trainset_view = raft::make_device_matrix_view<const float, internal_idx_t>(
sub_trainset.data(), n_rows, index.pq_len());
auto centers_tmp_view = raft::make_device_matrix_view<float, int64_t>(
auto centers_tmp_view = raft::make_device_matrix_view<float, internal_idx_t>(
pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j,
index.pq_book_size(),
index.pq_len());
auto sub_labels_view =
raft::make_device_vector_view<uint32_t, int64_t>(sub_labels.data(), n_rows);
auto cluster_sizes_view = raft::make_device_vector_view<uint32_t, int64_t>(
raft::make_device_vector_view<uint32_t, internal_idx_t>(sub_labels.data(), n_rows);
auto cluster_sizes_view = raft::make_device_vector_view<uint32_t, internal_idx_t>(
pq_cluster_sizes.data(), index.pq_book_size());
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = kmeans_n_iters;
Expand Down Expand Up @@ -526,16 +528,16 @@ void train_per_cluster(raft::resources const& handle,
size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim());
auto pq_n_rows = uint32_t(std::min(big_enough, available_rows));
// train PQ codebook for this cluster
auto rot_vectors_view = raft::make_device_matrix_view<const float, int64_t>(
auto rot_vectors_view = raft::make_device_matrix_view<const float, internal_idx_t>(
rot_vectors.data(), pq_n_rows, index.pq_len());
auto centers_tmp_view = raft::make_device_matrix_view<float, int64_t>(
auto centers_tmp_view = raft::make_device_matrix_view<float, internal_idx_t>(
pq_centers_tmp.data() + static_cast<size_t>(index.pq_book_size()) *
static_cast<size_t>(index.pq_len()) * static_cast<size_t>(l),
index.pq_book_size(),
index.pq_len());
auto pq_labels_view =
raft::make_device_vector_view<uint32_t, int64_t>(pq_labels.data(), pq_n_rows);
auto pq_cluster_sizes_view = raft::make_device_vector_view<uint32_t, int64_t>(
raft::make_device_vector_view<uint32_t, internal_idx_t>(pq_labels.data(), pq_n_rows);
auto pq_cluster_sizes_view = raft::make_device_vector_view<uint32_t, internal_idx_t>(
pq_cluster_sizes.data(), index.pq_book_size());
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = kmeans_n_iters;
Expand Down Expand Up @@ -1588,11 +1590,11 @@ void extend(raft::resources const& handle,
cudaMemcpyDefault,
stream));
for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, int64_t>(batch.data(), batch.size(), index->dim());
auto batch_labels_view = raft::make_device_vector_view<uint32_t, int64_t>(
auto batch_data_view = raft::make_device_matrix_view<const T, internal_idx_t>(
batch.data(), batch.size(), index->dim());
auto batch_labels_view = raft::make_device_vector_view<uint32_t, internal_idx_t>(
new_data_labels.data() + batch.offset(), batch.size());
auto centers_view = raft::make_device_matrix_view<const float, int64_t>(
auto centers_view = raft::make_device_matrix_view<const float, internal_idx_t>(
cluster_centers.data(), n_clusters, index->dim());
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.metric = index->metric();
Expand Down Expand Up @@ -1768,10 +1770,10 @@ auto build(raft::resources const& handle,
auto cluster_centers = cluster_centers_buf.data();

// Train balanced hierarchical kmeans clustering
auto trainset_const_view = raft::make_device_matrix_view<const float, int64_t>(
auto trainset_const_view = raft::make_device_matrix_view<const float, internal_idx_t>(
trainset.data(), n_rows_train, index.dim());
auto centers_view =
raft::make_device_matrix_view<float, int64_t>(cluster_centers, index.n_lists(), index.dim());
auto centers_view = raft::make_device_matrix_view<float, internal_idx_t>(
cluster_centers, index.n_lists(), index.dim());
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = params.kmeans_n_iters;
kmeans_params.metric = index.metric();
Expand All @@ -1780,10 +1782,10 @@ auto build(raft::resources const& handle,

// Trainset labels are needed for training PQ codebooks
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, device_memory);
auto centers_const_view = raft::make_device_matrix_view<const float, int64_t>(
auto centers_const_view = raft::make_device_matrix_view<const float, internal_idx_t>(
cluster_centers, index.n_lists(), index.dim());
auto labels_view =
raft::make_device_vector_view<uint32_t, int64_t>(labels.data(), n_rows_train);
raft::make_device_vector_view<uint32_t, internal_idx_t>(labels.data(), n_rows_train);
raft::cluster::kmeans_balanced::predict(handle,
kmeans_params,
trainset_const_view,
Expand Down

0 comments on commit 2bf77db

Please sign in to comment.