diff --git a/README.md b/README.md index 256e5d369..80c2b16c5 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ Please submit more benchmarks via raising an issue! ## Usage ### Installation and Build -To install mistral.rs, one should ensure they have Rust installed by following [this](https://rustup.rs/) link. Additionally, the Hugging Face token should be provided in `~/.cache/huggingface/token` when using the server to enable automatic download of gated models. +To install mistral.rs, one should ensure they have Rust installed by following [this](https://rustup.rs/) link. Additionally, the Hugging Face token should be provided in `~/.cache/huggingface/token` by running `huggingface-cli login` to enable automatic download of gated models. 1) Install required packages - `openssl` (ex., `sudo apt install libssl-dev`) @@ -169,9 +169,7 @@ To install mistral.rs, one should ensure they have Rust installed by following [ 3) Set HF token correctly (skip if already set or your model is not gated, or if you want to use the `token_source` parameters in Python or the command line.) ```bash - mkdir ~/.cache/huggingface - touch ~/.cache/huggingface/token - echo > ~/.cache/huggingface/token + huggingface-cli login ``` 4) Download the code @@ -220,7 +218,13 @@ To install mistral.rs, one should ensure they have Rust installed by following [ You can install Python support by following the guide [here](mistralrs-pyo3/README.md). -### Getting models from HF Hub +## Getting models + +There are 2 ways to run a model with mistral.rs: +- From Hugging Face Hub (easiest) +- From local files + +### Getting models from Hugging Face Hub Mistral.rs can automatically download models from HF Hub. To access gated models, you should provide a token source. They may be one of: - `literal:`: Load from a specified literal @@ -240,17 +244,12 @@ This is passed in the following ways: If token cannot be loaded, no token will be used (i.e. effectively using `none`). -## Loading models from local files:** +### Loading models from local files: -You can also instruct mistral.rs to load models locally by modifying the `*_model_id` arguments or options: +You can also instruct mistral.rs to load models fully locally by modifying the `*_model_id` arguments or options: ```bash ./mistralrs_server --port 1234 plain -m . -a mistral ``` -or - -```bash -./mistralrs-server gguf -m . -t . -f Phi-3-mini-128k-instruct-q4_K_M.gguf -``` Throughout mistral.rs, any model ID argument or option may be a local path and should contain the following files for each model ID option: - `--model-id` (server) or `model_id` (python/rust) or `--tok-model-id` (server) or `tok_model_id` (python/rust): @@ -267,7 +266,22 @@ Throughout mistral.rs, any model ID argument or option may be a local path and s - `--adapters-model-id` (server) or `adapters_model_id` (python/rust): - Adapters `.safetensors` and `adapter_config.json` files in their respective directories -### Run +### Running GGUF models locally + +To run GGUF models fully locally, you do not need to specify the tokenizer model ID argument and instead should pass a path to the +chat template JSON file (examples [here](chat_templates), you will need to create your own by specifying the chat template and `bos`/`eos` tokens) as well as specifying a local model ID. For example: + +```bash +./mistralrs-server --chat-template gguf -m . -f Phi-3-mini-128k-instruct-q4_K_M.gguf +``` + +The following tokenizer model types are currently supported. If you would like one to be added, please raise an issue. 
Otherwise, +please consider using the method demonstrated in examples below, where the tokenizer is sourced from Hugging Face. + +**Supported GGUF tokenizer types** +- `llama` + +## Run To start a server serving Mistral GGUF on `localhost:1234`, ```bash @@ -290,7 +304,7 @@ Additionally, for models without quantization, the model architecture should be You can launch interactive mode, a simple chat application running in the terminal, by passing `-i`: ```bash -./mistralrs_server -i gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf +./mistralrs_server -i plain -m microsoft/Phi-3-mini-128k-instruct -a phi3 ``` ### Quick examples: @@ -333,7 +347,7 @@ To start a server running Llama from GGML: To start a server running Mistral from safetensors. ```bash -./mistralrs_server --port 1234 gguf -m mistralai/Mistral-7B-Instruct-v0.1 +./mistralrs_server --port 1234 plain -m mistralai/Mistral-7B-Instruct-v0.1 -a mistral ``` ### Structured selection with a `.toml` file diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 2ba047587..9d4c9c199 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -56,6 +56,7 @@ toml = "0.8.12" strum = { version = "0.26", features = ["derive"] } derive_more = { version = "0.99.17", default-features = false, features = ["from"] } tracing-subscriber.workspace = true +reqwest = { version = "0.12.4", features = ["blocking"] } [features] pyo3_macros = ["pyo3"] diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index b7438b0f0..3d61eb62c 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -150,22 +150,19 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result GGUFLoaderBuilder::new( GGUFSpecificConfig { repeat_last_n }, args.chat_template, - tokenizer_json, - Some(tok_model_id), + tok_model_id, quantized_model_id, quantized_filename, ) .build(), ModelSelected::XLoraGGUF { tok_model_id, - tokenizer_json, quantized_model_id, quantized_filename, repeat_last_n, @@ -175,7 +172,6 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result GGUFLoaderBuilder::new( GGUFSpecificConfig { repeat_last_n }, args.chat_template, - tokenizer_json, tok_model_id, quantized_model_id, quantized_filename, @@ -192,7 +188,6 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result anyhow::Result GGUFLoaderBuilder::new( GGUFSpecificConfig { repeat_last_n }, args.chat_template, - tokenizer_json, tok_model_id, quantized_model_id, quantized_filename, diff --git a/mistralrs-core/src/model_selected.rs b/mistralrs-core/src/model_selected.rs index 6642c3f8f..a9ed08e5d 100644 --- a/mistralrs-core/src/model_selected.rs +++ b/mistralrs-core/src/model_selected.rs @@ -95,13 +95,11 @@ pub enum ModelSelected { /// Select a GGUF model. GGUF { - /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path. + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. #[arg(short, long)] - tok_model_id: String, - - /// Path to local tokenizer.json file. If this is specified it is used over any remote file. - #[arg(long)] - tokenizer_json: Option, + tok_model_id: Option, /// Quantized model ID to find the `quantized_filename`, only applicable if `quantized` is set. 
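For readers driving mistral.rs from Rust rather than the CLI, the fully local GGUF flow described above maps onto the updated `GGUFLoaderBuilder::new` signature introduced later in this diff: the tokenizer JSON argument is gone and the tokenizer model ID becomes optional. A minimal sketch, assuming the re-exports already used by `mistralrs/examples/quantized/main.rs`; the chat template path and the `.` model directory are placeholder values:

```rust
use mistralrs::{GGUFLoaderBuilder, GGUFSpecificConfig, Loader};

fn local_gguf_loader() -> Box<dyn Loader> {
    GGUFLoaderBuilder::new(
        GGUFSpecificConfig { repeat_last_n: 64 },
        // Chat template is read from a local JSON file (placeholder path).
        Some("chat_templates/mistral.json".to_string()),
        // No tokenizer model ID: the tokenizer is rebuilt from the GGUF metadata.
        None,
        // Local directory containing the quantized file.
        ".".to_string(),
        "Phi-3-mini-128k-instruct-q4_K_M.gguf".to_string(),
    )
    .build()
}
```

This is the Rust-API counterpart of the `./mistralrs-server --chat-template <file> gguf -m . -f Phi-3-mini-128k-instruct-q4_K_M.gguf` invocation shown above.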
/// This may be a HF hub repo or a local path. @@ -119,14 +117,12 @@ pub enum ModelSelected { /// Select a GGUF model with X-LoRA. XLoraGGUF { - /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path. + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. #[arg(short, long)] tok_model_id: Option, - /// Path to local tokenizer.json file. If this is specified it is used over any remote file. - #[arg(long)] - tokenizer_json: Option, - /// Quantized model ID to find the `quantized_filename`, only applicable if `quantized` is set. /// This may be a HF hub repo or a local path. #[arg(short = 'm', long)] @@ -156,14 +152,12 @@ pub enum ModelSelected { /// Select a GGUF model with LoRA. LoraGGUF { - /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path. + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. #[arg(short, long)] tok_model_id: Option, - /// Path to local tokenizer.json file. If this is specified it is used over any remote file. - #[arg(long)] - tokenizer_json: Option, - /// Quantized model ID to find the `quantized_filename`, only applicable if `quantized` is set. /// This may be a HF hub repo or a local path. #[arg(short = 'm', long)] diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 7532ce9e4..ae3bb9dca 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -7,14 +7,14 @@ use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; use crate::lora::Ordering; use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig}; +use crate::pipeline::gguf_tokenizer::convert_ggml_to_hf_tokenizer; use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; -use crate::utils::tokenizer::get_tokenizer; use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use crate::xlora_models::NonGranularState; -use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG}; +use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG}; use crate::{ models::quantized_llama::ModelWeights as QLlama, models::quantized_phi2::ModelWeights as QPhi, @@ -69,7 +69,6 @@ pub struct GGUFLoader { xlora_order: Option, no_kv_cache: bool, chat_template: Option, - tokenizer_json: Option, kind: ModelKind, tgt_non_granular_index: Option, } @@ -119,24 +118,24 @@ pub struct GGUFLoaderBuilder { xlora_order: Option, no_kv_cache: bool, chat_template: Option, - tokenizer_json: Option, tgt_non_granular_index: Option, } impl GGUFLoaderBuilder { + /// Create a loader builder for a GGUF model. `tok_model_id` is the model ID where you can find a + /// `tokenizer_config.json` file. If the `chat_template` is specified, then it will be treated as a + /// path and used over remote files, removing all remote accesses. 
pub fn new( config: GGUFSpecificConfig, chat_template: Option, - tokenizer_json: Option, - model_id: Option, + tok_model_id: Option, quantized_model_id: String, quantized_filename: String, ) -> Self { Self { config, chat_template, - tokenizer_json, - model_id, + model_id: tok_model_id, kind: ModelKind::QuantizedGGUF, quantized_filename, quantized_model_id, @@ -197,7 +196,6 @@ impl GGUFLoaderBuilder { xlora_order: self.xlora_order, no_kv_cache: self.no_kv_cache, chat_template: self.chat_template, - tokenizer_json: self.tokenizer_json, tgt_non_granular_index: self.tgt_non_granular_index, quantized_filename: Some(self.quantized_filename), quantized_model_id: Some(self.quantized_model_id), @@ -217,7 +215,6 @@ impl GGUFLoader { xlora_order: Option, no_kv_cache: bool, chat_template: Option, - tokenizer_json: Option, tgt_non_granular_index: Option, ) -> Self { let model_id = if let Some(id) = model_id { @@ -238,7 +235,6 @@ impl GGUFLoader { xlora_order, no_kv_cache, chat_template, - tokenizer_json, kind, tgt_non_granular_index, } @@ -279,7 +275,7 @@ impl Loader for GGUFLoader { mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { - let paths: anyhow::Result> = get_paths!( + let paths: anyhow::Result> = get_paths_gguf!( LocalModelPaths, &token_source, revision, @@ -360,6 +356,8 @@ impl Loader for GGUFLoader { info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); } + let tokenizer = convert_ggml_to_hf_tokenizer(&model)?; + let mut is_lora = false; let model = match self.kind { ModelKind::QuantizedGGUF => match arch { @@ -480,8 +478,6 @@ impl Loader for GGUFLoader { _ => unreachable!(), }; - let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?; - let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); diff --git a/mistralrs-core/src/pipeline/gguf_tokenizer.rs b/mistralrs-core/src/pipeline/gguf_tokenizer.rs new file mode 100644 index 000000000..1a8333616 --- /dev/null +++ b/mistralrs-core/src/pipeline/gguf_tokenizer.rs @@ -0,0 +1,256 @@ +use std::sync::atomic::Ordering; + +use anyhow::Result; +use candle_core::quantized::gguf_file::Content; +use tokenizers::{ + decoders::{self, byte_fallback::ByteFallback, fuse::Fuse, strip::Strip}, + models::unigram::Unigram, + normalizers::{self, Prepend, Replace}, + AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer, +}; +use tracing::info; + +use crate::DEBUG; + +pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { + let model = content.metadata["tokenizer.ggml.model"] + .to_string() + .expect("GGUF tokenizer model is not a string.") + .clone(); + let tokens = content.metadata["tokenizer.ggml.tokens"] + .to_vec() + .expect("GGUF tokenizer tokens is not a vec.") + .iter() + .map(|t| t.to_string().expect("GGUF token is not a string.").clone()) + .collect::>(); + let added_tokens = content + .metadata + .get("tokenizer.ggml.added_tokens") + .map(|items| { + items + .to_vec() + .expect("GGUF tokenizer added_tokens is not a vec.") + .iter() + .map(|t| { + t.to_string() + .expect("GGUF added_token is not a string.") + .clone() + }) + .collect::>() + }); + let scores = content.metadata.get("tokenizer.ggml.scores").map(|items| { + items + .to_vec() + .expect("GGUF tokenizer scores is not a vec.") + .iter() + .map(|t| t.to_f32().expect("GGUF score is not a f32.")) + .collect::>() + }); + let merges = content.metadata.get("tokenizer.ggml.merges").map(|items| { + items + .to_vec() + 
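Because the converter keys entirely off these `tokenizer.ggml.*` metadata entries, it can help to dump them when a GGUF file is rejected with an unsupported tokenizer model. A small diagnostic sketch using the same `candle` GGUF reader as the surrounding code; the file path is a placeholder and the helper is not part of this PR:

```rust
use candle_core::quantized::gguf_file::Content;

// Print the tokenizer-related GGUF metadata so unsupported models can be inspected.
fn dump_tokenizer_metadata(path: &str) -> anyhow::Result<()> {
    let mut file = std::fs::File::open(path)?;
    let content = Content::read(&mut file)
        .map_err(|e| e.with_path(path))
        .map_err(anyhow::Error::msg)?;

    for (key, value) in &content.metadata {
        if !key.starts_with("tokenizer.") {
            continue;
        }
        // Large arrays (tokens, scores, merges) are summarized by length only.
        if let Ok(items) = value.to_vec() {
            println!("{key}: {} entries", items.len());
        } else if let Ok(s) = value.to_string() {
            println!("{key}: {s}");
        } else if let Ok(n) = value.to_u32() {
            println!("{key}: {n}");
        } else {
            println!("{key}: <other scalar>");
        }
    }
    Ok(())
}
```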
.expect("GGUF tokenizer merges is not a vec.") + .iter() + .map(|t| t.to_string().expect("GGUF merges is not a string.").clone()) + .collect::>() + }); + + let unk = content.metadata["tokenizer.ggml.unknown_token_id"] + .to_u32() + .expect("GGUF unk token is not u32"); + + let eos = content.metadata["tokenizer.ggml.eos_token_id"] + .to_u32() + .expect("GGUF unk token is not u32"); + + let bos = content.metadata["tokenizer.ggml.bos_token_id"] + .to_u32() + .expect("GGUF unk token is not u32"); + + let (tokenizer, ty) = match model.as_str() { + "llama" | "replit" => { + // unigram + let scores = scores + .as_ref() + .expect("Expect `tokenizer.ggml.scores` for `llama` unigram tokeizer."); + let mut vocab = Vec::new(); + for (token, score) in tokens.iter().zip(scores) { + vocab.push((token.clone(), *score as f64)); + } + let unigram = + Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?; + let mut tokenizer = Tokenizer::new(ModelWrapper::Unigram(unigram)); + tokenizer.with_decoder(decoders::sequence::Sequence::new(vec![ + DecoderWrapper::Replace(Replace::new("▁", " ").map_err(anyhow::Error::msg)?), + DecoderWrapper::ByteFallback(ByteFallback::new()), + DecoderWrapper::Fuse(Fuse::new()), + DecoderWrapper::Strip(Strip::new(' ', 1, 0)), + ])); + tokenizer.with_normalizer(normalizers::Sequence::new(vec![ + NormalizerWrapper::Prepend(Prepend::new("▁".to_string())), + NormalizerWrapper::Replace(Replace::new(" ", "▁").map_err(anyhow::Error::msg)?), + ])); + + tokenizer.add_special_tokens(&[AddedToken::from(tokens[bos as usize].clone(), true)]); + tokenizer.add_special_tokens(&[AddedToken::from(tokens[eos as usize].clone(), true)]); + tokenizer.add_special_tokens(&[AddedToken::from(tokens[unk as usize].clone(), true)]); + + (tokenizer, "unigram") + } + other => { + anyhow::bail!("Tokenizer model `{other}` not supported."); + } + }; + info!( + "GGUF tokenizer model is `{model}`, kind: `{}`, num tokens: {}, num added tokens: {}, num merges: {}, num scores: {}", + ty, + tokenizer.get_vocab_size(true), + added_tokens.as_ref().map(|x| x.len()).unwrap_or(0), + merges.as_ref().map(|x| x.len()).unwrap_or(0), + scores.as_ref().map(|x| x.len()).unwrap_or(0) + ); + if DEBUG.load(Ordering::Relaxed) { + info!("Tokenizer: {tokenizer:?}"); + } + Ok(tokenizer) +} + +mod tests { + use anyhow::Result; + use candle_core::quantized::gguf_file::Content; + use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; + use tokenizers::Tokenizer; + + use super::convert_ggml_to_hf_tokenizer; + + #[allow(dead_code)] + #[derive(Debug)] + enum TokenizerType { + /// Mistral v0.1 tokenizer + Llama, + Replit, + Gpt2, + Rwkv, + } + + #[allow(dead_code)] + fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result { + match tokenizer { + TokenizerType::Llama => { + let api = ApiBuilder::new().with_progress(true).build().unwrap(); + let api = api.repo(Repo::with_revision( + "TheBloke/Mistral-7B-Instruct-v0.1-GGUF".to_string(), + RepoType::Model, + "main".to_string(), + )); + + let filename = api.get("mistral-7b-instruct-v0.1.Q2_K.gguf").unwrap(); + let mut file = std::fs::File::open(&filename)?; + convert_ggml_to_hf_tokenizer( + &Content::read(&mut file) + .map_err(|e| e.with_path(filename)) + .map_err(anyhow::Error::msg)?, + ) + .map_err(anyhow::Error::msg) + } + other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"), + } + } + + #[allow(dead_code)] + fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result { + match tokenizer { + TokenizerType::Llama => { + let api = 
ApiBuilder::new().with_progress(true).build().unwrap(); + let api = api.repo(Repo::with_revision( + "EricB/mistralrs_tests".to_string(), + RepoType::Model, + "main".to_string(), + )); + + let tokenizer_filename = api.get("tokenizer.json").unwrap(); + Ok(Tokenizer::from_file(tokenizer_filename).unwrap()) + } + other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"), + } + } + + #[allow(dead_code)] + fn get_test_passage() -> String { + let passage = reqwest::blocking::get("https://loripsum.net/api") + .expect("Failed to download sample text") + .bytes() + .expect("Failed to get bytes"); + String::from_utf8(passage.to_vec()).expect("Failed to convert sample text to string.") + } + + #[test] + fn test_encode_llama() -> Result<()> { + let passage = get_test_passage(); + let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?; + let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?; + + // Without special tokens + let hf_tokenized = hf_tokenizer + .encode(passage.as_str(), false) + .map_err(anyhow::Error::msg)?; + let gguf_tokenized = gguf_tokenizer + .encode(passage.as_str(), false) + .map_err(anyhow::Error::msg)?; + let hf_decoded = hf_tokenizer + .decode(hf_tokenized.get_ids(), false) + .map_err(anyhow::Error::msg)?; + let gguf_decoded = gguf_tokenizer + .decode(gguf_tokenized.get_ids(), false) + .map_err(anyhow::Error::msg)?; + assert_eq!(hf_decoded, gguf_decoded); + + // With special tokens + let hf_tokenized = hf_tokenizer + .encode(passage.as_str(), true) + .map_err(anyhow::Error::msg)?; + let gguf_tokenized = gguf_tokenizer + .encode(passage.as_str(), true) + .map_err(anyhow::Error::msg)?; + let hf_decoded = hf_tokenizer + .decode(hf_tokenized.get_ids(), true) + .map_err(anyhow::Error::msg)?; + let gguf_decoded = gguf_tokenizer + .decode(gguf_tokenized.get_ids(), true) + .map_err(anyhow::Error::msg)?; + assert_eq!(hf_decoded, gguf_decoded); + Ok(()) + } + + #[test] + fn test_decode_llama() -> Result<()> { + use rand::seq::SliceRandom; + use rand::thread_rng; + + let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?; + let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?; + + #[allow(clippy::cast_possible_truncation)] + let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::>(); + tokens.shuffle(&mut thread_rng()); + + // Without skipping special tokens + let hf_decoded = hf_tokenizer + .decode(&tokens, false) + .map_err(anyhow::Error::msg)?; + let gguf_decoded = gguf_tokenizer + .decode(&tokens, false) + .map_err(anyhow::Error::msg)?; + assert_eq!(hf_decoded, gguf_decoded); + + // With skipping special tokens + let hf_decoded = hf_tokenizer + .decode(&tokens, true) + .map_err(anyhow::Error::msg)?; + let gguf_decoded = gguf_tokenizer + .decode(&tokens, true) + .map_err(anyhow::Error::msg)?; + assert_eq!(hf_decoded, gguf_decoded); + Ok(()) + } +} diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 25068ccad..7f8f663d5 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -138,6 +138,89 @@ macro_rules! get_paths { }}; } +#[macro_export] +macro_rules! get_paths_gguf { + ($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filename:expr, $silent:expr) => {{ + let api = ApiBuilder::new() + .with_progress(!$silent) + .with_token(get_token($token_source)?) 
+ .build()?; + let revision = $revision.unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision( + $this.model_id.clone(), + RepoType::Model, + revision.clone(), + )); + let model_id = std::path::Path::new(&$this.model_id); + + let chat_template = if let Some(ref p) = $this.chat_template { + if p.ends_with(".json") { + info!("Using chat template file at `{p}`"); + PathBuf::from_str(p)? + } else { + PathBuf::from_str("")? + } + } else { + $crate::api_get_file!( + api, + "tokenizer_config.json", + model_id + ) // Will be loaded from inside gguf file + }; + + let filenames = get_model_paths( + revision.clone(), + &$token_source, + &$quantized_model_id, + &$quantized_filename, + &api, + &model_id, + )?; + + let XLoraPaths { + adapter_configs, + adapter_safetensors, + classifier_path, + xlora_order, + xlora_config, + lora_preload_adapter_info, + } = get_xlora_paths( + $this.model_id.clone(), + &$this.xlora_model_id, + &$token_source, + revision.clone(), + &$this.xlora_order, + )?; + + let gen_conf = if $crate::api_dir_list!(api, model_id) + .collect::>() + .contains(&"generation_config.json".to_string()) + { + Some($crate::api_get_file!( + api, + "generation_config.json", + model_id + )) + } else { + None + }; + + Ok(Box::new($path_name { + tokenizer_filename: PathBuf::from_str("")?, + config_filename: PathBuf::from_str("")?, + filenames, + xlora_adapter_configs: adapter_configs, + xlora_adapter_filenames: adapter_safetensors, + classifier_path, + classifier_config: xlora_config, + xlora_ordering: xlora_order, + template_filename: chat_template, + gen_conf, + lora_preload_adapter_info, + })) + }}; +} + #[macro_export] macro_rules! normal_model_loader { ($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index de4135ec6..9d7dbee83 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -2,6 +2,7 @@ mod cache_manager; mod chat_template; mod ggml; mod gguf; +mod gguf_tokenizer; mod loaders; mod macros; mod normal; @@ -1297,8 +1298,25 @@ pub(crate) fn get_chat_template( paths: &Box, chat_template: &Option, ) -> ChatTemplate { + let template_filename = if paths.get_template_filename().to_string_lossy().is_empty() { + PathBuf::from( + chat_template + .as_ref() + .expect("A tokenizer config or chat template file path must be specified."), + ) + } else { + paths.get_template_filename().clone() + }; + if template_filename + .extension() + .expect("Template filename must be a file") + .to_string_lossy() + != "json" + { + panic!("Template filename {template_filename:?} must end with `.json`."); + } let template: ChatTemplate = - serde_json::from_str(&fs::read_to_string(paths.get_template_filename()).unwrap()).unwrap(); + serde_json::from_str(&fs::read_to_string(&template_filename).unwrap()).unwrap(); #[derive(Debug, serde::Deserialize)] struct SpecifiedTemplate { @@ -1313,7 +1331,7 @@ pub(crate) fn get_chat_template( info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template."); let mut deser: HashMap = - serde_json::from_str(&fs::read_to_string(paths.get_template_filename()).unwrap()).unwrap(); + serde_json::from_str(&fs::read_to_string(&template_filename).unwrap()).unwrap(); match chat_template.clone() { Some(t) => { diff --git a/mistralrs-core/src/sampler.rs 
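The updated `get_chat_template` logic in `pipeline/mod.rs` reduces to a small precedence rule, restated here as a standalone sketch (an illustration, not the code in the PR): whatever template file the loader recorded wins; otherwise the `--chat-template` value is treated as a path; and in both cases the file must be JSON.

```rust
use std::path::PathBuf;

// Condensed restatement of the template-file selection in `get_chat_template`.
fn resolve_template_file(loader_template: PathBuf, chat_template_flag: Option<&str>) -> PathBuf {
    let chosen = if loader_template.as_os_str().is_empty() {
        // Nothing was recorded by the loader (e.g. a fully local GGUF run),
        // so fall back to the explicit --chat-template path.
        PathBuf::from(
            chat_template_flag
                .expect("A tokenizer config or chat template file path must be specified."),
        )
    } else {
        loader_template
    };
    assert!(
        chosen.extension().map(|e| e == "json").unwrap_or(false),
        "template filename {chosen:?} must end with `.json`"
    );
    chosen
}
```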
b/mistralrs-core/src/sampler.rs index 520b139f0..a8da56c10 100644 --- a/mistralrs-core/src/sampler.rs +++ b/mistralrs-core/src/sampler.rs @@ -413,11 +413,7 @@ mod tests { #[allow(dead_code)] fn get_tokenizer() -> Tokenizer { - let api = ApiBuilder::new() - .with_progress(true) - .with_token(Some(std::env::var("TESTS_HF_TOKEN").unwrap())) - .build() - .unwrap(); + let api = ApiBuilder::new().with_progress(true).build().unwrap(); let api = api.repo(Repo::with_revision( "EricB/mistralrs_tests".to_string(), RepoType::Model, diff --git a/mistralrs-core/src/toml_selector.rs b/mistralrs-core/src/toml_selector.rs index 5bf67276c..478d940eb 100644 --- a/mistralrs-core/src/toml_selector.rs +++ b/mistralrs-core/src/toml_selector.rs @@ -65,7 +65,9 @@ enum TomlModelSelected { /// Select a GGUF model. #[allow(clippy::upper_case_acronyms)] GGUF { - /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path. + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. tok_model_id: String, /// Quantized model ID to find the `quantized_filename`, only applicable if `quantized` is set. @@ -78,7 +80,9 @@ enum TomlModelSelected { /// Select a GGUF model with X-LoRA. XLoraGGUF { - /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path. + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. tok_model_id: Option, /// Quantized model ID to find the `quantized_filename`, only applicable if `quantized` is set. @@ -101,7 +105,9 @@ enum TomlModelSelected { /// Select a GGUF model with LoRA. LoraGGUF { - /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path. + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. tok_model_id: Option, /// Quantized model ID to find the `quantized_filename`, only applicable if `quantized` is set. 
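On the test side, the `sampler.rs` change above follows the same pattern as the new tokenizer tests: every repository referenced is public, so the `TESTS_HF_TOKEN` environment variable is no longer needed and the Hub is queried anonymously. A minimal sketch of that pattern (the helper name is illustrative, not part of the PR):

```rust
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;

// Anonymous download of a public repo's tokenizer.json, as the tests now do.
fn fetch_public_tokenizer(repo_id: &str) -> anyhow::Result<Tokenizer> {
    let api = ApiBuilder::new().with_progress(true).build()?;
    let repo = api.repo(Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    let tokenizer_path = repo.get("tokenizer.json")?;
    Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
```

For example, `fetch_public_tokenizer("EricB/mistralrs_tests")` reproduces what `get_tokenizer` does in the sampler tests.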
@@ -299,7 +305,6 @@ fn loader_from_selected( repeat_last_n: args.repeat_last_n, }, args.chat_template, - args.tokenizer_json, Some(tok_model_id), quantized_model_id, quantized_filename, @@ -317,7 +322,6 @@ fn loader_from_selected( repeat_last_n: args.repeat_last_n, }, args.chat_template, - args.tokenizer_json, tok_model_id, quantized_model_id, quantized_filename, @@ -343,7 +347,6 @@ fn loader_from_selected( repeat_last_n: args.repeat_last_n, }, args.chat_template, - args.tokenizer_json, tok_model_id, quantized_model_id, quantized_filename, diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md index 7d0387348..359ac00e8 100644 --- a/mistralrs-pyo3/API.md +++ b/mistralrs-pyo3/API.md @@ -22,11 +22,13 @@ Additionally, for models without quantization, the model architecture should be ```py class Which(Enum): + @dataclass class Plain: model_id: str arch: Architecture tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class XLora: arch: Architecture xlora_model_id: str @@ -35,6 +37,7 @@ class Which(Enum): model_id: str | None = None tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class Lora: arch: Architecture adapters_model_id: str @@ -42,12 +45,13 @@ class Which(Enum): model_id: str | None = None tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class GGUF: tok_model_id: str quantized_model_id: str quantized_filename: str - tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class XLoraGGUF: tok_model_id: str quantized_model_id: str @@ -55,22 +59,23 @@ class Which(Enum): xlora_model_id: str order: str tgt_non_granular_index: int | None = None - tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class LoraGGUF: tok_model_id: str quantized_model_id: str quantized_filename: str adapters_model_id: str order: str - tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class GGML: tok_model_id: str quantized_model_id: str quantized_filename: str tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class XLoraGGML: tok_model_id: str quantized_model_id: str @@ -80,6 +85,7 @@ class Which(Enum): tgt_non_granular_index: int | None = None tokenizer_json: str | None = None repeat_last_n: int = 64 + @dataclass class LoraGGML: tok_model_id: str quantized_model_id: str diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index a1239a557..f1d7c46c7 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -96,7 +96,6 @@ class Which(Enum): tok_model_id: str quantized_model_id: str quantized_filename: str - tokenizer_json: str | None = None repeat_last_n: int = 64 @dataclass class XLoraGGUF: @@ -106,7 +105,6 @@ class Which(Enum): xlora_model_id: str order: str tgt_non_granular_index: int | None = None - tokenizer_json: str | None = None repeat_last_n: int = 64 @dataclass class LoraGGUF: @@ -115,7 +113,6 @@ class Which(Enum): quantized_filename: str adapters_model_id: str order: str - tokenizer_json: str | None = None repeat_last_n: int = 64 @dataclass class GGML: diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index b1e5f8a83..ae0ec9d3b 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -167,7 +167,6 @@ fn parse_which( .build(arch.into()), Which::GGUF { tok_model_id, - tokenizer_json, quantized_model_id, quantized_filename, repeat_last_n, @@ -176,15 +175,13 @@ fn parse_which( repeat_last_n: repeat_last_n.unwrap_or(REPEAT_LAST_N_DEFAULT), }, chat_template, - tokenizer_json, - 
Some(tok_model_id), + tok_model_id, quantized_model_id, quantized_filename, ) .build(), Which::XLoraGGUF { tok_model_id, - tokenizer_json, quantized_model_id, quantized_filename, repeat_last_n, @@ -196,7 +193,6 @@ fn parse_which( repeat_last_n: repeat_last_n.unwrap_or(REPEAT_LAST_N_DEFAULT), }, chat_template, - tokenizer_json, tok_model_id, quantized_model_id, quantized_filename, @@ -214,7 +210,6 @@ fn parse_which( .build(), Which::LoraGGUF { tok_model_id, - tokenizer_json, quantized_model_id, quantized_filename, repeat_last_n, @@ -225,7 +220,6 @@ fn parse_which( repeat_last_n: repeat_last_n.unwrap_or(REPEAT_LAST_N_DEFAULT), }, chat_template, - tokenizer_json, tok_model_id, quantized_model_id, quantized_filename, diff --git a/mistralrs-pyo3/src/which.rs b/mistralrs-pyo3/src/which.rs index f7def2cfb..a5a33a612 100644 --- a/mistralrs-pyo3/src/which.rs +++ b/mistralrs-pyo3/src/which.rs @@ -56,8 +56,7 @@ pub enum Which { #[allow(clippy::upper_case_acronyms)] GGUF { - tok_model_id: String, - tokenizer_json: Option, + tok_model_id: Option, quantized_model_id: String, quantized_filename: String, repeat_last_n: Option, @@ -65,7 +64,6 @@ pub enum Which { XLoraGGUF { tok_model_id: Option, - tokenizer_json: Option, quantized_model_id: String, quantized_filename: String, repeat_last_n: Option, @@ -76,7 +74,6 @@ pub enum Which { LoraGGUF { tok_model_id: Option, - tokenizer_json: Option, quantized_model_id: String, quantized_filename: String, repeat_last_n: Option, diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs index 37f60ef01..58f1ac92b 100644 --- a/mistralrs/examples/quantized/main.rs +++ b/mistralrs/examples/quantized/main.rs @@ -12,7 +12,6 @@ fn setup() -> anyhow::Result> { let loader = GGUFLoaderBuilder::new( GGUFSpecificConfig { repeat_last_n: 64 }, None, - None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), "TheBloke/Mistral-7B-Instruct-v0.1-GGUF".to_string(), "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
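Downstream of this API change, the only adjustment callers make is dropping the old tokenizer JSON argument, so `GGUFLoaderBuilder::new` takes exactly the five arguments visible in the hunk above. A sketch of how the quantized example's setup reads as a whole under the new signature (imports assumed to match the example's existing ones; the closing lines fall outside the hunk shown):

```rust
use mistralrs::{GGUFLoaderBuilder, GGUFSpecificConfig, Loader};

// Hub-backed variant: tokenizer_config.json comes from the tokenizer model ID,
// so no chat template file is passed.
fn setup_loader() -> Box<dyn Loader> {
    GGUFLoaderBuilder::new(
        GGUFSpecificConfig { repeat_last_n: 64 },
        None,
        Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()),
        "TheBloke/Mistral-7B-Instruct-v0.1-GGUF".to_string(),
        "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
    )
    .build()
}
```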