From cf0a512ef8546229469e56b6efcfdd2b030916cb Mon Sep 17 00:00:00 2001 From: Anush Date: Tue, 10 Oct 2023 21:29:57 +0530 Subject: [PATCH] feat: Reading config from model files (#2) * feat: Reading config from model files * test: canonical value tests * chore: refactor tests * deps: remove tokenizers default feats * ci: cache outputs * fix: ci cache * fix: cache ci name * ci: check cache before tests --- .github/workflows/test.yml | 12 ++ Cargo.lock | 255 ++----------------------------------- Cargo.toml | 5 +- src/lib.rs | 166 ++++++++++++++++-------- 4 files changed, 138 insertions(+), 300 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6ef41c0..8717f42 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,5 +14,17 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + local_cache/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - name: Run tests run: cargo test \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 3394da1..e5a7233 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,54 +17,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "anstream" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" -dependencies = [ - "anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "utf8parse", -] - -[[package]] -name = "anstyle" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" - -[[package]] -name = "anstyle-parse" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" -dependencies = [ - "utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" -dependencies = [ - "windows-sys 0.48.0", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" -dependencies = [ - "anstyle", - "windows-sys 0.48.0", -] - [[package]] name = "anyhow" version = "1.0.75" @@ -122,65 +74,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "clap" -version = "4.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" -dependencies = [ - "clap_builder", - "clap_derive", -] - -[[package]] -name = "clap_builder" -version = "4.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" -dependencies = [ - "anstream", - "anstyle", - "clap_lex", - "strsim", -] - -[[package]] -name = "clap_derive" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "clap_lex" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" - -[[package]] -name = "colorchoice" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" - -[[package]] -name = "console" -version = "0.15.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" -dependencies = [ - "encode_unicode", - "lazy_static", - "libc", - "unicode-width", - "windows-sys 0.45.0", -] - [[package]] name = "crc32fast" version = "1.3.2" @@ -301,20 +194,11 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" -[[package]] -name = "encode_unicode" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" - [[package]] name = "esaxx-rs" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" -dependencies = [ - "cc", -] [[package]] name = "fastembed" @@ -326,6 +210,7 @@ dependencies = [ "ndarray", "ort", "rayon", + "serde_json", "tar", "tokenizers", ] @@ -339,7 +224,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "windows-sys 0.48.0", + "windows-sys", ] [[package]] @@ -388,12 +273,6 @@ dependencies = [ "crunchy", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "ident_case" version = "1.0.1" @@ -410,28 +289,6 @@ dependencies = [ "unicode-normalization", ] -[[package]] -name = "indicatif" -version = "0.17.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" -dependencies = [ - "console", - "instant", - "number_prefix", - "portable-atomic", - "unicode-width", -] - -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "itertools" version = "0.11.0" @@ -625,12 +482,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - [[package]] name = "once_cell" version = "1.18.0" @@ -704,12 +555,6 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" -[[package]] -name = "portable-atomic" -version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31114a898e107c51bb1609ffaf55a0e011cf6a4d7f1170d0015a165082c0338b" - [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1040,11 +885,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9be88c795d8b9f9c4002b3a8f26a6d0876103a6f523b32ea3bac52d8560c17c" dependencies = [ "aho-corasick", - "clap", "derive_builder", "esaxx-rs", "getrandom", - "indicatif", "itertools", "lazy_static", "log", @@ -1134,12 +977,6 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" -[[package]] -name = "unicode-width" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" - [[package]] name = "unicode_categories" version = "0.1.1" @@ -1178,12 +1015,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "utf8parse" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" - [[package]] name = "vswhom" version = "0.1.0" @@ -1302,37 +1133,13 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows-sys" -version = "0.45.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" -dependencies = [ - "windows-targets 0.42.2", -] - [[package]] name = "windows-sys" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows-targets", ] [[package]] @@ -1341,93 +1148,51 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" - [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" -[[package]] -name = "windows_aarch64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" -[[package]] -name = "windows_i686_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" - [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" -[[package]] -name = "windows_i686_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" - [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" -[[package]] -name = "windows_x86_64_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" - [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" -[[package]] -name = "windows_x86_64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index a70ec3a..3925e9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fastembed" -version = "1.4.3" +version = "1.4.2" edition = "2021" description = "Rust implementation of https://github.com/qdrant/fastembed" license = "MIT" @@ -18,5 +18,6 @@ minreq = { version = "2.10", default-features = false, features = ["https-rustls ndarray = { version = "0.15", default-features = false } ort = { version = "1", features = ["load-dynamic"] } rayon = { version = "1.7", default-features = false } +serde_json = {version = "1"} tar = { version = "0.4", default-features = false } -tokenizers = { version = "0.14" } +tokenizers = { version = "0.14", default-features = false, features = ["onig"]} diff --git a/src/lib.rs b/src/lib.rs index 2a9f4af..69f367a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ //! use std::{ + fs::File, path::{Path, PathBuf}, thread::available_parallelism, }; @@ -90,7 +91,7 @@ use rayon::{ slice::ParallelSlice, }; use tar::Archive; -use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams}; +use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, TruncationParams}; const DEFAULT_BATCH_SIZE: usize = 256; const DEFAULT_MAX_LENGTH: usize = 512; @@ -177,7 +178,8 @@ pub trait EmbeddingBase> { /// Rust representation of the FlagEmbedding model pub struct FlagEmbedding { tokenizer: Tokenizer, - model: Session, + session: Session, + model: EmbeddingModel, } impl FlagEmbedding { @@ -198,37 +200,28 @@ impl FlagEmbedding { let threads = available_parallelism()?.get() as i16; let model_path = - FlagEmbedding::retrieve_model(model_name, &cache_dir, show_download_message)?; + FlagEmbedding::retrieve_model(model_name.clone(), &cache_dir, show_download_message)?; let environment = Environment::builder() .with_name("Fastembed") .with_execution_providers(execution_providers) .build()?; - let model = SessionBuilder::new(&environment.into())? + let session = SessionBuilder::new(&environment.into())? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(threads)? .with_model_from_file(model_path.join("model_optimized.onnx"))?; - let mut tokenizer = - tokenizers::Tokenizer::from_file(model_path.join("tokenizer.json")).unwrap(); - let tokenizer: Tokenizer = tokenizer - .with_truncation(Some(TruncationParams { - max_length, - ..Default::default() - })) - .unwrap() - .with_padding(Some(PaddingParams { - strategy: PaddingStrategy::Fixed(max_length), - pad_token: "[PAD]".into(), - ..Default::default() - })) - .clone(); - Ok(Self::new(tokenizer, model)) + let tokenizer = FlagEmbedding::load_tokenizer(model_path, max_length)?; + Ok(Self::new(tokenizer, session, model_name)) } /// Private method to return an instance - fn new(tokenizer: Tokenizer, model: Session) -> Self { - Self { tokenizer, model } + fn new(tokenizer: Tokenizer, session: Session, model: EmbeddingModel) -> Self { + Self { + tokenizer, + session, + model, + } } /// Download and unpack the model from Google Cloud Storage @@ -286,6 +279,67 @@ impl FlagEmbedding { Ok(output_path) } + fn load_tokenizer(model_path: PathBuf, max_length: usize) -> Result { + let config_path = model_path.join("config.json"); + let file = File::open(config_path)?; + let config: serde_json::Value = serde_json::from_reader(file)?; + + let tokenizer_config_path = model_path.join("tokenizer_config.json"); + let file = File::open(tokenizer_config_path)?; + let tokenizer_config: serde_json::Value = serde_json::from_reader(file)?; + + let special_tokens_map_path = model_path.join("special_tokens_map.json"); + let file = File::open(special_tokens_map_path)?; + let special_tokens_map: serde_json::Value = serde_json::from_reader(file)?; + + let tokenizer_path = model_path.join("tokenizer.json"); + let mut tokenizer = + tokenizers::Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::Error::msg(e))?; + + //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64 + let model_max_length = tokenizer_config["model_max_length"].as_f64().unwrap(); + let max_length = max_length.min(model_max_length as usize); + let pad_id = config["pad_token_id"] + .as_u64() + .expect("couldn't parse pad_token_id") as u32; + let pad_token = tokenizer_config["pad_token"].as_str().unwrap().into(); + + let mut tokenizer = tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::Fixed(max_length), + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length, + ..Default::default() + })) + .map_err(|e| anyhow::Error::msg(e))? + .clone(); + if let serde_json::Value::Object(root_object) = special_tokens_map { + for (_, value) in root_object.iter() { + if value.is_string() { + tokenizer.add_special_tokens(&[AddedToken { + content: value.as_str().unwrap().into(), + special: true, + ..Default::default() + }]); + } else if value.is_object() { + tokenizer.add_special_tokens(&[AddedToken { + content: value["content"].as_str().unwrap().into(), + special: true, + single_word: value["single_word"].as_bool().unwrap(), + lstrip: value["lstrip"].as_bool().unwrap(), + rstrip: value["rstrip"].as_bool().unwrap(), + normalized: value["normalized"].as_bool().unwrap(), + }]); + } + } + } + Ok(tokenizer) + } + /// Retrieve a list of supported modelsc pub fn list_supported_models() -> Vec { vec![ModelInfo { @@ -371,13 +425,19 @@ impl + Send + Sync> EmbeddingBase for FlagEmbedding { )?) .into_dyn(); - // Run the model with inputs - let outputs = self.model.run(vec![ - Value::from_array(self.model.allocator(), &inputs_ids_array)?, - Value::from_array(self.model.allocator(), &attention_mask_array)?, - Value::from_array(self.model.allocator(), &token_type_ids_array)?, - ])?; + let mut inputs = vec![ + Value::from_array(self.session.allocator(), &inputs_ids_array)?, + Value::from_array(self.session.allocator(), &attention_mask_array)?, + Value::from_array(self.session.allocator(), &token_type_ids_array)?, + ]; + + // Remove the token_type_ids_array if the model is MLE5Large + if let EmbeddingModel::MLE5Large = self.model { + inputs.pop(); + } + // Run the model with inputs + let outputs = self.session.run(inputs)?; // Extract and normalize embeddings let output_data = outputs[0].try_extract::()?; let view = output_data.view(); @@ -441,7 +501,7 @@ fn get_embeddings(data: &[f32], dimensions: &[usize]) -> Vec { #[cfg(test)] mod tests { use super::*; - const epsilon: f32 = 1e-4; + const EPSILON: f32 = 1e-4; #[test] fn test_bgesmall() { @@ -462,7 +522,7 @@ mod tests { for (i, v) in expected.into_iter().enumerate() { let difference = (v - embeddings[0][i]).abs(); - assert!(difference < epsilon, "Difference: {}", difference) + assert!(difference < EPSILON, "Difference: {}", difference) } } @@ -484,7 +544,7 @@ mod tests { for (i, v) in expected.into_iter().enumerate() { let difference = (v - embeddings[0][i]).abs(); - assert!(difference < epsilon, "Difference: {}", difference) + assert!(difference < EPSILON, "Difference: {}", difference) } } @@ -507,30 +567,30 @@ mod tests { for (i, v) in expected.into_iter().enumerate() { let difference = (v - embeddings[0][i]).abs(); - assert!(difference < epsilon, "Difference: {}", difference) + assert!(difference < EPSILON, "Difference: {}", difference) } } - // #[test] - // fn test_mle5large() { - // let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions { - // model_name: EmbeddingModel::MLE5Large, - // ..Default::default() - // }) - // .unwrap(); - - // let expected: Vec = vec![ - // 0.00961, 0.00443, 0.00658, -0.03532, 0.00703, -0.02878, -0.03671, 0.03482, 0.06343, - // -0.04731, - // ]; - // let documents = vec!["hello world"]; - - // // Generate embeddings with the default batch size, 256 - // let embeddings = model.embed(documents, None).unwrap(); - - // for (i, v) in expected.into_iter().enumerate() { - // let difference = (v - embeddings[0][i]).abs(); - // assert!(difference < epsilon, "Difference: {}", difference) - // } - // } + #[test] + fn test_mle5large() { + let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions { + model_name: EmbeddingModel::MLE5Large, + ..Default::default() + }) + .unwrap(); + + let expected: Vec = vec![ + 0.00961, 0.00443, 0.00658, -0.03532, 0.00703, -0.02878, -0.03671, 0.03482, 0.06343, + -0.04731, + ]; + let documents = vec!["hello world"]; + + // Generate embeddings with the default batch size, 256 + let embeddings = model.embed(documents, None).unwrap(); + + for (i, v) in expected.into_iter().enumerate() { + let difference = (v - embeddings[0][i]).abs(); + assert!(difference < EPSILON, "Difference: {}", difference) + } + } }