diff --git a/nucliadb_vectors/src/data_point/disk_hnsw.rs b/nucliadb_vectors/src/data_point/disk_hnsw.rs index 59e8db33bf..4c0a9de634 100644 --- a/nucliadb_vectors/src/data_point/disk_hnsw.rs +++ b/nucliadb_vectors/src/data_point/disk_hnsw.rs @@ -112,7 +112,7 @@ impl<'a> Iterator for EdgeIter<'a> { let edge = f32_from_le_bytes(&buf[crnt..(crnt + EDGE_LEN)]); crnt += EDGE_LEN; self.crnt = crnt; - Some((Address(node), Edge { dist: edge })) + Some((Address(node), edge)) } } } @@ -138,7 +138,7 @@ impl DiskHnsw { length += USIZE_LEN; for (cnx, edge) in hnsw.get_layer(layer).get_out_edges(node) { buf.write_all(&cnx.0.to_le_bytes())?; - buf.write_all(&edge.dist.to_le_bytes())?; + buf.write_all(&edge.to_le_bytes())?; length += CNX_LEN; } } @@ -242,9 +242,9 @@ mod tests { fn hnsw_test() { let no_nodes = 3; let cnx0 = vec![ - vec![(Address(1), Edge { dist: 1.0 })], - vec![(Address(2), Edge { dist: 2.0 })], - vec![(Address(3), Edge { dist: 3.0 })], + vec![(Address(1), 1.0)], + vec![(Address(2), 2.0)], + vec![(Address(3), 3.0)], ]; let layer0 = RAMLayer { out: cnx0 @@ -253,10 +253,7 @@ mod tests { .map(|(i, c)| (Address(i), c.clone())) .collect(), }; - let cnx1 = vec![ - vec![(Address(1), Edge { dist: 4.0 })], - vec![(Address(2), Edge { dist: 5.0 })], - ]; + let cnx1 = vec![vec![(Address(1), 4.0)], vec![(Address(2), 5.0)]]; let layer1 = RAMLayer { out: cnx1 .iter() @@ -264,7 +261,7 @@ mod tests { .map(|(i, c)| (Address(i), c.clone())) .collect(), }; - let cnx2 = vec![vec![(Address(1), Edge { dist: 6.0 })]]; + let cnx2 = vec![vec![(Address(1), 6.0)]]; let layer2 = RAMLayer { out: cnx2 .iter() diff --git a/nucliadb_vectors/src/data_point/ops_hnsw.rs b/nucliadb_vectors/src/data_point/ops_hnsw.rs index ae1998b775..4756ca7a5b 100644 --- a/nucliadb_vectors/src/data_point/ops_hnsw.rs +++ b/nucliadb_vectors/src/data_point/ops_hnsw.rs @@ -116,7 +116,7 @@ impl<'a, DR: DataRetriever> HnswOps<'a, DR> { k_neighbours: usize, mut candidates: Vec<(Address, Edge)>, ) -> Vec<(Address, Edge)> { - candidates.sort_unstable_by_key(|(n, d)| std::cmp::Reverse(Cnx(*n, d.dist))); + candidates.sort_unstable_by_key(|(n, d)| std::cmp::Reverse(Cnx(*n, *d))); candidates.dedup_by_key(|(addr, _)| *addr); candidates.truncate(k_neighbours); candidates @@ -152,7 +152,7 @@ impl<'a, DR: DataRetriever> HnswOps<'a, DR> { } Some((down, _)) => { let mut sorted_out: Vec<_> = layer.get_out_edges(down).collect(); - sorted_out.sort_by(|a, b| b.1.dist.total_cmp(&a.1.dist)); + sorted_out.sort_by(|a, b| b.1.total_cmp(&a.1)); sorted_out.into_iter().for_each(|(new_candidate, _)| { if !visited_nodes.contains(&new_candidate) { candidates.push_back(new_candidate); @@ -216,14 +216,15 @@ impl<'a, DR: DataRetriever> HnswOps<'a, DR> { ) -> Vec
{ use params::*; let neighbours = self.layer_search::<&RAMLayer>(x, layer, ef_construction(), entry_points); + let neighbours = self.select_neighbours_heuristic(m_max(), neighbours); let mut needs_repair = HashSet::new(); let mut result = Vec::with_capacity(neighbours.len()); layer.add_node(x); for (y, dist) in neighbours.iter().copied() { result.push(y); - layer.add_edge(x, Edge { dist }, y); - layer.add_edge(y, Edge { dist }, x); - if layer.no_out_edges(y) > 2 * m_max() { + layer.add_edge(x, dist, y); + layer.add_edge(y, dist, x); + if layer.no_out_edges(y) > m_max() { needs_repair.insert(y); } } diff --git a/nucliadb_vectors/src/data_point/params.rs b/nucliadb_vectors/src/data_point/params.rs index 5f83d8c800..dd0813f677 100644 --- a/nucliadb_vectors/src/data_point/params.rs +++ b/nucliadb_vectors/src/data_point/params.rs @@ -28,7 +28,7 @@ pub fn level_factor() -> f64 { /// Upper limit to the number of out-edges a embedding can have. pub const fn m_max() -> usize { - 30 + 60 } /// Number of bi-directional links created for every new element. diff --git a/nucliadb_vectors/src/data_point/ram_hnsw.rs b/nucliadb_vectors/src/data_point/ram_hnsw.rs index 8f419563a5..5828c7923b 100644 --- a/nucliadb_vectors/src/data_point/ram_hnsw.rs +++ b/nucliadb_vectors/src/data_point/ram_hnsw.rs @@ -33,10 +33,7 @@ pub struct EntryPoint { pub layer: usize, } -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Serialize, Deserialize)] -pub struct Edge { - pub dist: f32, -} +pub type Edge = f32; #[derive(Default, Clone)] pub struct RAMLayer {