Skip to content

Commit

Permalink
refactor: Use tokenizer batch_encode
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Oct 3, 2023
1 parent 587b423 commit 4ebef86
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ use rayon::{
slice::ParallelSlice,
};
use tar::Archive;
- use tokenizers::{Encoding, PaddingParams, PaddingStrategy, TruncationParams};
+ use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};

const DEFAULT_BATCH_SIZE: usize = 256;
const DEFAULT_MAX_LENGTH: usize = 512;
Expand Down Expand Up @@ -162,7 +162,7 @@ pub struct ModelInfo {
pub description: String,
}

- /// Base class for implemnting an embedding model
+ /// Base for implementing an embedding model
pub trait EmbeddingBase<S: AsRef<str>> {
/// The base embedding method for generating sentence embeddings
fn embed(&self, texts: Vec<S>, batch_size: Option<usize>) -> Result<Vec<Embedding>>;
Expand Down Expand Up @@ -317,10 +317,8 @@ impl<S: AsRef<str> + Send + Sync> EmbeddingBase<S> for FlagEmbedding {
.par_chunks(batch_size)
.map(|batch| {
// Encode the texts in the batch
- let encodings: Vec<Encoding> = batch
- .iter()
- .map(|text| self.tokenizer.encode(text.as_ref(), true).unwrap())
- .collect();
+ let inputs = batch.iter().map(|text| text.as_ref()).collect();
+ let encodings = self.tokenizer.encode_batch(inputs, true).unwrap();

// Extract the encoding length and batch size
let encoding_length = encodings[0].len();
Expand Down

0 comments on commit 4ebef86

Please sign in to comment.