forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
IndexReplicas.cpp
123 lines (101 loc) · 3.8 KB
/
IndexReplicas.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexReplicas.h>
#include <faiss/impl/FaissAssert.h>
namespace faiss {
/// Constructs a replica container without specifying a dimension.
///
/// @param threaded forwarded unchanged to the ThreadedIndex<IndexT> base
///                 constructor; its exact meaning is defined by that base
///                 class (presumably whether replicas run on their own
///                 threads — confirm against ThreadedIndex).
template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
: ThreadedIndex<IndexT>(threaded) {
}
/// Constructs a replica container for a given dimension.
///
/// @param d        forwarded to the ThreadedIndex<IndexT> base constructor
///                 (the vector dimension, per the parameter name).
/// @param threaded forwarded unchanged to the ThreadedIndex<IndexT> base
///                 constructor.
template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
: ThreadedIndex<IndexT>(d, threaded) {
}
/// Same as the (idx_t, bool) constructor, but taking the dimension as `int`.
///
/// NOTE(review): presumably exists so that a plain `int` argument selects a
/// dimension-taking overload instead of being ambiguous between the `bool`
/// and `idx_t` overloads — confirm against the header.
///
/// @param d        forwarded to the ThreadedIndex<IndexT> base constructor.
/// @param threaded forwarded unchanged to the ThreadedIndex<IndexT> base
///                 constructor.
template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
: ThreadedIndex<IndexT>(d, threaded) {
}
/// Hook invoked after `index` has been registered as a replica.
///
/// If another index already occupies slot 0, the newcomer must agree with it
/// on `ntotal` and `is_trained`; otherwise an exception is raised through the
/// FAISS_THROW_IF_NOT_* macros. If `index` is itself the first replica, the
/// wrapper adopts its `ntotal`, `verbose`, `is_trained` and `metric_type`.
///
/// @param index the replica that was just added.
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
    // Make sure that the parameters are the same for all prior indices, unless
    // we're the first index to be added
    if (this->count() > 0 && this->at(0) != index) {
        auto existing = this->at(0);

        // ntotal is a 64-bit index type; print it through an explicit
        // `long long` so that the format specifier matches the argument on
        // every data model (with `%ld` it would not where `long` is 32 bits).
        FAISS_THROW_IF_NOT_FMT(index->ntotal == existing->ntotal,
                               "IndexReplicas: newly added index does "
                               "not have same number of vectors as prior index; "
                               "prior index has %lld vectors, new index has %lld",
                               static_cast<long long>(existing->ntotal),
                               static_cast<long long>(index->ntotal));

        FAISS_THROW_IF_NOT_MSG(index->is_trained == existing->is_trained,
                               "IndexReplicas: newly added index does "
                               "not have same train status as prior index");
    } else {
        // Set our parameters based on the first index we're adding
        // (dimension is handled in ThreadedIndex)
        this->ntotal = index->ntotal;
        this->verbose = index->verbose;
        this->is_trained = index->is_trained;
        this->metric_type = index->metric_type;
    }
}
/// Trains every replica on the same training set.
///
/// Each replica's train(n, x) is invoked through runOnIndex(). Afterwards the
/// wrapper's own `is_trained` flag is refreshed from the first replica.
///
/// @param n number of training vectors, passed through to each replica.
/// @param x training data, passed through to each replica.
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::train(idx_t n, const component_t* x) {
    this->runOnIndex([n, x](int, IndexT* index) { index->train(n, x); });

    // onAfterAddIndex() only snapshots `is_trained` when the first replica is
    // registered, so without this refresh the wrapper would keep reporting
    // the pre-training state. Assumes runOnIndex() has finished running the
    // callback on all replicas when it returns (search() relies on the same).
    if (this->count() > 0) {
        this->is_trained = this->at(0)->is_trained;
    }
}
/// Adds the same `n` vectors from `x` to every replica, then advances the
/// wrapper's own `ntotal` by `n`.
///
/// @param n number of vectors, passed through to each replica's add().
/// @param x vector data, passed through to each replica's add().
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::add(idx_t n, const component_t* x) {
    // The replica ordinal supplied by runOnIndex() is not needed here.
    auto addToReplica = [n, x](int /*replicaNo*/, IndexT* replica) {
        replica->add(n, x);
    };
    this->runOnIndex(addToReplica);

    this->ntotal += n;
}
/// Delegates reconstruction to the first replica.
///
/// @param n forwarded as-is to the first replica's reconstruct().
/// @param x output buffer forwarded as-is to the first replica's
///          reconstruct().
/// Raises (via FAISS_THROW_IF_NOT_MSG) when there are no replicas.
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::reconstruct(idx_t n, component_t* x) const {
    FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index");

    // Only replica 0 is consulted.
    auto firstReplica = this->at(0);
    firstReplica->reconstruct(n, x);
}
/// Splits a batch of `n` queries into contiguous chunks and hands one chunk
/// to each replica through runOnIndex().
///
/// Replica `i` receives the queries starting at `i * ceil(n / count())` and
/// writes its results into the matching slices of `distances` and `labels`
/// (offset by `start * k`). Replicas whose start offset is past `n` do
/// nothing.
///
/// @param n         number of query vectors; returns immediately when 0.
/// @param x         query data; each query occupies `(d + 7) / 8` components
///                  when `component_t` is one byte wide, `d` otherwise.
/// @param k         forwarded to each replica's search() and used as the
///                  per-query stride of the output arrays.
/// @param distances output array, indexed as `query * k`.
/// @param labels    output array, indexed as `query * k`.
/// Raises (via FAISS_THROW_IF_NOT_MSG) when there are no replicas.
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::search(
        idx_t n,
        const component_t* x,
        idx_t k,
        distance_t* distances,
        idx_t* labels) const {
    FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index");

    if (n == 0) {
        return;
    }

    using QueryIdx = faiss::Index::idx_t;

    // Number of component_t elements that make up one query vector
    auto dim = this->d;
    size_t vecStride = dim;
    if (sizeof(component_t) == 1) {
        vecStride = (dim + 7) / 8;
    }

    // Ceiling division of the queries over the available replicas
    const QueryIdx replicaCount = static_cast<QueryIdx>(this->count());
    const QueryIdx perReplica =
            static_cast<QueryIdx>(n + this->count() - 1) / replicaCount;
    FAISS_ASSERT(n / perReplica <= this->count());

    auto searchSlice = [perReplica, vecStride, n, x, k, distances, labels](
                               int replicaNo, const IndexT* replica) {
        const QueryIdx start = static_cast<QueryIdx>(replicaNo) * perReplica;
        if (!(start < n)) {
            // No queries left for this replica
            return;
        }

        // The last active replica may receive a short chunk
        auto chunkLen = std::min(perReplica, n - start);
        replica->search(
                chunkLen,
                x + start * vecStride,
                k,
                distances + start * k,
                labels + start * k);
    };

    this->runOnIndex(searchSlice);
}
// Explicit instantiations: the member templates above are defined in this
// translation unit, so the two supported specializations (Index and
// IndexBinary) are instantiated here.
template struct IndexReplicasTemplate<Index>;
template struct IndexReplicasTemplate<IndexBinary>;
} // namespace