Skip to content

Commit

Permalink
Merge branch 'canonical-value-tests' into read-config-files
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Oct 10, 2023
2 parents 139e560 + b01c93d commit d2f61fe
Showing 1 changed file with 81 additions and 11 deletions.
92 changes: 81 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,9 @@ impl FlagEmbedding {
let mut tokenizer =
tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::Error::msg(e))?;

let max_length =
max_length.min(tokenizer_config["model_max_length"].as_u64().unwrap() as usize);
//For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64
let model_max_length = tokenizer_config["model_max_length"].as_f64().unwrap();
let max_length = max_length.min(model_max_length as usize);
let pad_id = config["pad_token_id"]
.as_u64()
.expect("couldn't parse pad_token_id") as u32;
Expand Down Expand Up @@ -500,27 +501,96 @@ fn get_embeddings(data: &[f32], dimensions: &[usize]) -> Vec<Embedding> {
#[cfg(test)]
mod tests {
use super::*;
const EPSILON: f32 = 1e-4;

#[test]
fn test_mle5_large() {
fn test_bgesmall() {
let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
model_name: EmbeddingModel::BGESmallEN,
..Default::default()
})
.unwrap();

let expected: Vec<f32> = vec![
-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.022123, -0.01472, 0.039255,
0.034447, 0.004598,
];
let documents = vec!["hello world"];

// Generate embeddings with the default batch size, 256
let embeddings = model.embed(documents, None).unwrap();

for (i, v) in expected.into_iter().enumerate() {
let difference = (v - embeddings[0][i]).abs();
assert!(difference < EPSILON, "Difference: {}", difference)
}
}

#[test]
fn test_bgebase() {
let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
model_name: EmbeddingModel::BGEBaseEN,
..Default::default()
})
.unwrap();

let expected: Vec<f32> = vec![
0.0114, 0.03722, 0.02941, 0.0123, 0.03451, 0.00876, 0.02356, 0.05414, -0.0294, -0.0547,
];
let documents = vec!["hello world"];

// Generate embeddings with the default batch size, 256
let embeddings = model.embed(documents, None).unwrap();

for (i, v) in expected.into_iter().enumerate() {
let difference = (v - embeddings[0][i]).abs();
assert!(difference < EPSILON, "Difference: {}", difference)
}
}

#[test]
fn test_allminilm() {
let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
model_name: EmbeddingModel::AllMiniLML6V2,
show_download_message: false,
..Default::default()
})
.unwrap();

let documents = vec![
"passage: Hello, World!",
"query: Hello, World!",
"passage: This is an example passage.",
// You can leave out the prefix but it's recommended
"fastembed-rs is licensed under MIT",
let expected: Vec<f32> = vec![
0.02591, 0.00573, 0.01147, 0.03796, -0.0232, -0.0549, 0.01404, -0.0107, -0.0244,
-0.01822,
];
let documents = vec!["hello world"];

// Generate embeddings with the default batch size, 256
let embeddings = model.embed(documents, None).unwrap();

println!("Embeddings length: {}", embeddings.len());
for (i, v) in expected.into_iter().enumerate() {
let difference = (v - embeddings[0][i]).abs();
assert!(difference < EPSILON, "Difference: {}", difference)
}
}

#[test]
fn test_mle5large() {
let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
model_name: EmbeddingModel::MLE5Large,
..Default::default()
})
.unwrap();

let expected: Vec<f32> = vec![
0.00961, 0.00443, 0.00658, -0.03532, 0.00703, -0.02878, -0.03671, 0.03482, 0.06343,
-0.04731,
];
let documents = vec!["hello world"];

// Generate embeddings with the default batch size, 256
let embeddings = model.embed(documents, None).unwrap();

for (i, v) in expected.into_iter().enumerate() {
let difference = (v - embeddings[0][i]).abs();
assert!(difference < EPSILON, "Difference: {}", difference)
}
}
}

0 comments on commit d2f61fe

Please sign in to comment.