From a8b88cc2903f32e83c2e8ec61f93b620ab25e8bb Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 19:08:48 -0500 Subject: [PATCH 01/16] add smollm2 --- litgpt/config.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++ litgpt/prompts.py | 11 +++++++ 2 files changed, 84 insertions(+) diff --git a/litgpt/config.py b/litgpt/config.py index 684f3f78be..3a9370d2fd 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2043,4 +2043,77 @@ def norm_class(self) -> Type: configs.extend(qwq) +############### +# SmolLM2 +############### +smollm2 = [ + # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json + dict( + name="SmolLM2-135M{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=30, + n_head=9, + n_embd=576, + n_query_groups=3, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=1536, + rope_base=100000, + norm_eps=1e-5, + ), + # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json + dict( + name="SmolLM2-360M{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=32, + n_head=15, + n_embd=960, + n_query_groups=5, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=2560, + rope_base=100000, + norm_eps=1e-5, + ), + # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json + dict( + name="SmolLM2-1.7B{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=24, + n_head=32, + n_embd=2048, + n_query_groups=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8192, + rope_base=130000, + norm_eps=1e-5, + ), +] + +for c in smollm2: + for kind in ("", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 5f5fd14494..1c7b010838 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -284,6 +284,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class SmolLM2(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face" + return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + class QwQ(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: @@ -316,6 +321,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "olmo": OLMo, "qwen2.5": Qwen2_5, "qwq": QwQ, + "smollm2": SmolLM2 # SmolLM uses a different template } @@ -356,8 +362,13 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return OLMo() if re.search(r"Qwen2\.5-.*", model_name): return Qwen2_5() +<<<<<<< HEAD if re.search(r"QwQ-.*", model_name): return QwQ() +======= + if re.search(r"SmolLM2.*", model_name): + return SmolLM2() +>>>>>>> 961a30a (added smollm2) return Default() From 76a2f014ab28f6f1290feeaad8be8dc7326360c6 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 19:17:41 -0500 Subject: [PATCH 02/16] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3856a332ea..e5cce78d99 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,7 @@ Every model is written from scratch to maximize performance and remove layers of | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | From 805bdcf1be4caadff18952ee577063a9a2ed8e19 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 19:19:54 -0500 Subject: [PATCH 03/16] Update download_model_weights.md --- tutorials/download_model_weights.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 509218ac96..62602f91df 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -38,6 +38,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | @@ -194,6 +195,12 @@ Qwen/Qwen2.5-Coder-14B-Instruct Qwen/Qwen2.5-Coder-32B Qwen/Qwen2.5-Coder-32B-Instruct Qwen/QwQ-32B-Preview +HuggingFaceTB/SmolLM2-135M +HuggingFaceTB/SmolLM2-135M-Instruct +HuggingFaceTB/SmolLM2-360M +HuggingFaceTB/SmolLM2-360M-Instruct +HuggingFaceTB/SmolLM2-1.7B +HuggingFaceTB/SmolLM2-1.7B-Instruct stabilityai/FreeWilly2 stabilityai/stable-code-3b stabilityai/stablecode-completion-alpha-3b From b8079d99e526bcd8806c0f35f77ab3aac2862e21 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 19:20:18 -0500 Subject: [PATCH 04/16] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e5cce78d99..c270bddf8e 100644 --- a/README.md +++ b/README.md @@ -138,10 +138,10 @@ Every model is written from scratch to maximize performance and remove layers of | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | From ea3639d8f85f7b1b50d27966ca6cb24f756d336e Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 19:26:25 -0500 Subject: [PATCH 05/16] fix: forgot to resolve conflicts --- litgpt/prompts.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 1c7b010838..6a99004de2 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -362,13 +362,10 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return OLMo() if re.search(r"Qwen2\.5-.*", model_name): return Qwen2_5() -<<<<<<< HEAD if re.search(r"QwQ-.*", model_name): return QwQ() -======= if re.search(r"SmolLM2.*", model_name): return SmolLM2() ->>>>>>> 961a30a (added smollm2) return Default() From 07a1954d5570e8a8125469c838b8781f2472d686 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 20:31:53 -0500 Subject: [PATCH 06/16] fix: attempt to fix SmolLM2 tokenizer --- litgpt/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index a81c59aa2d..ed78ca550d 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -94,7 +94,7 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: config = json.load(fp) # for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only # `PreTrainedTokenizerFast` - if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): + if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3", "SmolLM2")): return True if "add_bos_token" in config: return config["add_bos_token"] From 5c726b052e2f1a43b4d418dff6e26e2809164251 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 2 Dec 2024 21:01:11 -0500 Subject: [PATCH 07/16] revert --- litgpt/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index ed78ca550d..a81c59aa2d 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -94,7 +94,7 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: config = json.load(fp) # for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only # `PreTrainedTokenizerFast` - if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3", "SmolLM2")): + if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): return True if "add_bos_token" in config: return config["add_bos_token"] From 6dd3353c8257dd963128d597070a169cba468884 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 3 Dec 2024 11:06:24 -0500 Subject: [PATCH 08/16] fix tokenizer --- litgpt/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index a81c59aa2d..ed78ca550d 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -94,7 +94,7 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: config = json.load(fp) # for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only # `PreTrainedTokenizerFast` - if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): + if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3", "SmolLM2")): return True if "add_bos_token" in config: return config["add_bos_token"] From a46e3f8304acca13154520e2f1523521c96bac8e Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 3 Dec 2024 11:22:44 -0500 Subject: [PATCH 09/16] SmolLM2: fix different bos_token depending on base or instruct --- litgpt/tokenizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index ed78ca550d..41aa0dd08a 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -94,7 +94,9 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: config = json.load(fp) # for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only # `PreTrainedTokenizerFast` - if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3", "SmolLM2")): + if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): + return True + if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.stem.endswith("-Instruct"): return True if "add_bos_token" in config: return config["add_bos_token"] From d06b5ceb134d28b78c6fba1cf790ab099856bd71 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 3 Dec 2024 12:14:28 -0500 Subject: [PATCH 10/16] SmolLM2: fixed path specification for 1.7B-Instruct --- litgpt/tokenizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index 41aa0dd08a..acfc0493a7 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int: raise ValueError(f"token {token!r} not found in the collection.") return id_ - def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: + def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file(): return False with open(tokenizer_config_path, encoding="utf-8") as fp: @@ -96,7 +96,7 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: # `PreTrainedTokenizerFast` if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): return True - if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.stem.endswith("-Instruct"): + if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"): return True if "add_bos_token" in config: return config["add_bos_token"] From 2fafc05e4a8b5942071f41ecebfef10f24651f0a Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 3 Dec 2024 18:18:50 -0500 Subject: [PATCH 11/16] SmolLM2: add test --- tests/test_model.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 3ca5e80599..210ed88108 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -851,6 +851,65 @@ def test_against_original_qwen_2_5(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-1.7B")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_smollm2(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + @RunIf(dynamo=True) @torch.inference_mode() def test_model_compile(): From 8e5e44edc4e691b0c59b76b5161acc00ec1ca38b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 3 Dec 2024 18:37:01 -0500 Subject: [PATCH 12/16] SmolLM2: minor fix on test_model.py script --- litgpt/prompts.py | 9 +++++---- tests/test_model.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 6a99004de2..35c029fba7 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -284,15 +284,16 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" -class SmolLM2(PromptStyle): + +class QwQ(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face" + system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" -class QwQ(PromptStyle): +class SmolLM2(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." + system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face" return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" diff --git a/tests/test_model.py b/tests/test_model.py index 210ed88108..867f60d635 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -879,6 +879,7 @@ def test_against_original_smollm2(model_name, device, dtype): n_layer=2, n_head=8, n_embd=32, + n_query_groups=2, intermediate_size=86, ) T = 5 From deba0e82156bb37c2b22466d58a4cde5263ecbd9 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 8 Dec 2024 01:20:28 -0500 Subject: [PATCH 13/16] smollm2 final revisions --- litgpt/prompts.py | 2 +- tutorials/download_model_weights.md | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 35c029fba7..be433ad0d4 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -365,7 +365,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() - if re.search(r"SmolLM2.*", model_name): + if re.search(r"SmolLM2.*-Instruct", model_name): return SmolLM2() return Default() diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 62602f91df..c18a40bdda 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -116,6 +116,12 @@ google/gemma-2b-it google/gemma-7b google/gemma-7b-it h2oai/h2o-danube2-1.8b-chat +HuggingFaceTB/SmolLM2-135M +HuggingFaceTB/SmolLM2-135M-Instruct +HuggingFaceTB/SmolLM2-360M +HuggingFaceTB/SmolLM2-360M-Instruct +HuggingFaceTB/SmolLM2-1.7B +HuggingFaceTB/SmolLM2-1.7B-Instruct lmsys/longchat-13b-16k lmsys/longchat-7b-16k lmsys/vicuna-13b-v1.3 @@ -195,12 +201,6 @@ Qwen/Qwen2.5-Coder-14B-Instruct Qwen/Qwen2.5-Coder-32B Qwen/Qwen2.5-Coder-32B-Instruct Qwen/QwQ-32B-Preview -HuggingFaceTB/SmolLM2-135M -HuggingFaceTB/SmolLM2-135M-Instruct -HuggingFaceTB/SmolLM2-360M -HuggingFaceTB/SmolLM2-360M-Instruct -HuggingFaceTB/SmolLM2-1.7B -HuggingFaceTB/SmolLM2-1.7B-Instruct stabilityai/FreeWilly2 stabilityai/stable-code-3b stabilityai/stablecode-completion-alpha-3b From 549115387c72a75301b3a270f0abfbb6028313e5 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 8 Dec 2024 02:09:54 -0500 Subject: [PATCH 14/16] smollm2 very minor fix --- litgpt/prompts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 83a96ac43e..f5b59e4e90 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -327,7 +327,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "olmo": OLMo, "qwen2.5": Qwen2_5, "qwq": QwQ, - "smollm2": SmolLM2 # SmolLM uses a different template + "smollm2": SmolLM2, # SmolLM uses a different template "salamandra": Salamandra, } From a8f06da36b3cb7dcee0f5590b8003f82e9ec457a Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Sun, 8 Dec 2024 18:27:19 +0300 Subject: [PATCH 15/16] Filter bin files to include only model weights --- litgpt/scripts/download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py index c1af2af133..fc6c153fad 100644 --- a/litgpt/scripts/download.py +++ b/litgpt/scripts/download.py @@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s with gated_repo_catcher(repo_id, access_token): info = repo_info(repo_id, token=access_token) filenames = [f.rfilename for f in info.siblings] - bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"])) + bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"])) safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"])) return bins, safetensors From 1f1e7372c1f5b2c7dd1e44536739ca8f4ce482ed Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 16 Dec 2024 00:21:42 -0500 Subject: [PATCH 16/16] smollm2: minor fixes --- litgpt/prompts.py | 2 +- tests/test_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index ab9cfee594..09b3277c7d 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -332,7 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "qwen2.5": Qwen2_5, "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, - "smollm2": SmolLM2, # SmolLM uses a different template + "smollm2": SmolLM2, "salamandra": Salamandra, } diff --git a/tests/test_model.py b/tests/test_model.py index e3aad3bb0e..89e926d173 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -914,7 +914,7 @@ def test_against_original_salamandra(model_name, device, dtype): @torch.inference_mode() -@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-1.7B")) +@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B")) @pytest.mark.parametrize( ("device", "dtype"), [