Add SmolLM2 #1848

Merged: 20 commits, Dec 16, 2024
1 change: 1 addition & 0 deletions README.md
@@ -140,6 +140,7 @@ Every model is written from scratch to maximize performance and remove layers of
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
76 changes: 75 additions & 1 deletion litgpt/config.py
@@ -2134,10 +2134,10 @@ def norm_class(self) -> Type:

configs.extend(qwq)


#############
# Salamandra
#############

salamandra = [
    # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json
    dict(
@@ -2189,4 +2189,78 @@ def norm_class(self) -> Type:
        configs.append(copy)


###############
# SmolLM2
###############
smollm2 = [
    # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json
    dict(
        name="SmolLM2-135M{}",
        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"),
        block_size=8192,
        vocab_size=49152,
        padded_vocab_size=49152,
        n_layer=30,
        n_head=9,
        n_embd=576,
        n_query_groups=3,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=1536,
        rope_base=100000,
        norm_eps=1e-5,
    ),
    # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json
    dict(
        name="SmolLM2-360M{}",
        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"),
        block_size=8192,
        vocab_size=49152,
        padded_vocab_size=49152,
        n_layer=32,
        n_head=15,
        n_embd=960,
        n_query_groups=5,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=2560,
        rope_base=100000,
        norm_eps=1e-5,
    ),
    # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json
    dict(
        name="SmolLM2-1.7B{}",
        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"),
        block_size=8192,
        vocab_size=49152,
        padded_vocab_size=49152,
        n_layer=24,
        n_head=32,
        n_embd=2048,
        n_query_groups=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=8192,
        rope_base=130000,
        norm_eps=1e-5,
    ),
]

for c in smollm2:
    for kind in ("", "-Instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


name_to_config = {config["name"]: config for config in configs}
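For context (not part of the diff): the `{}`/`-Instruct` expansion above registers two entries per base model, six SmolLM2 configs in total, so both base and instruct checkpoints resolve through `Config.from_name`. A minimal sketch, assuming a litgpt install that includes this change:

```python
from litgpt.config import Config, name_to_config

# Each base config is registered twice by the expansion loop above.
assert "SmolLM2-360M" in name_to_config
assert "SmolLM2-360M-Instruct" in name_to_config

# Hyperparameters come straight from the dicts defined in this diff.
cfg = Config.from_name("SmolLM2-1.7B-Instruct")
print(cfg.n_layer, cfg.n_embd, cfg.rope_base)  # 24 2048 130000
```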
9 changes: 9 additions & 0 deletions litgpt/prompts.py
@@ -300,6 +300,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


class SmolLM2(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
        system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


# Maps prompt style names to PromptStyle classes
prompt_styles: Dict[str, Type[PromptStyle]] = {
    # Dataset-specific prompt styles
@@ -326,6 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
"smollm2": SmolLM2,
"salamandra": Salamandra,
}

@@ -371,6 +378,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
        return Qwen2_5()
    if re.search(r"QwQ-.*", model_name):
        return QwQ()
    if re.search(r"SmolLM2.*-Instruct", model_name):
        return SmolLM2()
    if re.search(r"salamandra-.*-instruct", model_name):
        return Salamandra()
    return Default()
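For reference, a standalone sketch (assuming this branch is installed) of the chat template the new class renders; it mirrors the ChatML format used by the other Qwen-style prompt classes above:

```python
from litgpt.prompts import SmolLM2

print(SmolLM2().apply("What is the capital of France?"))
# <|im_start|>system
# You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
# <|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant
```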
2 changes: 1 addition & 1 deletion litgpt/scripts/download.py
@@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
    with gated_repo_catcher(repo_id, access_token):
        info = repo_info(repo_id, token=access_token)
    filenames = [f.rfilename for f in info.siblings]
    bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"]))
    bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
    safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
    return bins, safetensors
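The tightened `allow_patterns` is the one behavioral change in this file: `*.bin*` matched any `.bin` file, while `*model*.bin*` restricts the match to files with "model" in the name. A standalone sketch of the difference, using `filter_repo_objects` from `huggingface_hub` with an illustrative (hypothetical) file list:

```python
from huggingface_hub.utils import filter_repo_objects

# Hypothetical repo contents; only the model shard is a weight file.
filenames = ["pytorch_model.bin", "training_args.bin", "model.safetensors"]

print(list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"])))
# ['pytorch_model.bin', 'training_args.bin']
print(list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"])))
# ['pytorch_model.bin']
```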

4 changes: 3 additions & 1 deletion litgpt/tokenizer.py
@@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int:
            raise ValueError(f"token {token!r} not found in the collection.")
        return id_

    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
        if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
            return False
        with open(tokenizer_config_path, encoding="utf-8") as fp:
@@ -96,6 +96,8 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
        # `PreTrainedTokenizerFast`
        if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")):
            return True
        if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"):
            return True
        if "add_bos_token" in config:
            return config["add_bos_token"]
        # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
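A side note on the new branch: `Path.stem` strips everything after the last dot, so for `SmolLM2-1.7B-Instruct` the stem is just `SmolLM2-1`, which is why the check combines `stem.startswith("SmolLM2")` with `name.endswith("Instruct")`. A standalone sketch (hypothetical checkpoint directories, no downloads needed):

```python
from pathlib import Path

for d in ("SmolLM2-135M", "SmolLM2-1.7B", "SmolLM2-1.7B-Instruct"):
    p = Path("checkpoints/HuggingFaceTB") / d  # hypothetical checkpoint dirs
    uses_bos = p.stem.startswith("SmolLM2") and p.name.endswith("Instruct")
    print(d, uses_bos)
# SmolLM2-135M False
# SmolLM2-1.7B False
# SmolLM2-1.7B-Instruct True
```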
61 changes: 61 additions & 0 deletions tests/test_model.py
@@ -852,6 +852,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
@pytest.mark.parametrize(
@@ -910,6 +911,66 @@ def test_against_original_salamandra(model_name, device, dtype):
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_smollm2(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    ours_config = Config.from_name(
        model_name,
        padded_vocab_size=10000,
        n_layer=2,
        n_head=8,
        n_embd=32,
        n_query_groups=2,
        intermediate_size=86,
    )
    T = 5
    theirs_config = LlamaConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.bias,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    theirs_model = LlamaForCausalLM(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    state_dict = {}
    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)


@RunIf(dynamo=True)
7 changes: 7 additions & 0 deletions tutorials/download_model_weights.md
@@ -40,6 +40,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
@@ -122,6 +123,12 @@ google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
h2oai/h2o-danube2-1.8b-chat
HuggingFaceTB/SmolLM2-135M
HuggingFaceTB/SmolLM2-135M-Instruct
HuggingFaceTB/SmolLM2-360M
HuggingFaceTB/SmolLM2-360M-Instruct
HuggingFaceTB/SmolLM2-1.7B
HuggingFaceTB/SmolLM2-1.7B-Instruct
lmsys/longchat-13b-16k
lmsys/longchat-7b-16k
lmsys/vicuna-13b-v1.3