Fix model export for 7b (#222)
Summary:
- Rename `n_layer` -> `n_layers`.
- Wrap the checkpoint path with `str(checkpoint_path)`.
- Load the checkpoint with `strict=False`.

Test Plan:
```
python generate.py --compile \
    --checkpoint-path="/home/kimishpatel/models/llama2/7b/consolidated.00.pth" \
    --params-path="/home/kimishpatel/models/llama2/7b/params_32k_vocab.json" \
    --prompt "Hello, my name is" \
    --device cpu
```
kimishpatel authored Apr 16, 2024
1 parent 32fdb5a commit d6f557c
Showing 3 changed files with 15 additions and 15 deletions.
build/builder.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -216,7 +216,7 @@ def _load_model_not_gguf(builder_args):
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
 
-    model.load_state_dict(checkpoint, assign=True)
+    model.load_state_dict(checkpoint, assign=True, strict=False)
 
     if builder_args.use_tp:
         from tp import apply_tp
```
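For readers unfamiliar with the flag, here is a minimal, self-contained sketch (using a toy `nn.Sequential`, not the repo's Transformer, with illustrative key names) of what `strict=False` buys: `load_state_dict` tolerates key mismatches between the checkpoint and the module and reports them instead of raising.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; the key names below are illustrative only.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# A checkpoint whose keys drift from the module definition: one layer is
# missing and one key is unknown to the module.
checkpoint = {
    "0.weight": torch.zeros(4, 4),
    "0.bias": torch.zeros(4),
    "rope.freqs": torch.zeros(4),  # e.g. an export-only buffer
}

# With strict=True (the default) this would raise a RuntimeError.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # keys the module expects but the checkpoint lacks
print(result.unexpected_keys)  # checkpoint keys the module silently ignores
```

With `assign=True` (already present in the line above), the checkpoint tensors are assigned into the module rather than copied in place.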
build/gguf_loader.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -72,7 +72,7 @@ def _create_pt_model(
 ) -> nn.Module:
     llama_model_args = ModelArgs(
         dim=gguf_model_args.embedding_length,
-        n_layer=gguf_model_args.block_count,
+        n_layers=gguf_model_args.block_count,
         n_heads=gguf_model_args.attention.head_count,
         n_local_heads=gguf_model_args.attention.head_count_kv,
         vocab_size=gguf_model_args.vocab_size,
```
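The keyword rename matters because `ModelArgs` is constructed by keyword, so an out-of-date name fails at construction time rather than producing a mis-configured model. A standalone sketch of that failure mode, using a toy dataclass rather than the repo's `ModelArgs`:

```python
from dataclasses import dataclass


@dataclass
class ToyArgs:  # stand-in for ModelArgs after the rename
    n_layers: int = 32
    dim: int = 4096


print(ToyArgs(n_layers=32, dim=4096))  # fine with the new field name

try:
    ToyArgs(n_layer=32, dim=4096)  # the old keyword no longer exists
except TypeError as err:
    print(err)  # "... got an unexpected keyword argument 'n_layer'"
```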
build/model.py (26 changes: 13 additions & 13 deletions)
```diff
@@ -25,7 +25,7 @@ def find_multiple(n: int, k: int) -> int:
 class ModelArgs:
     block_size: int = 2048
     vocab_size: int = 32000
-    n_layer: int = 32
+    n_layers: int = 32
     # n_head in gpt-fast
     n_heads: int = 32
     dim: int = 4096
@@ -96,47 +96,47 @@ def from_name(cls, name: str):
 
 transformer_configs = {
     "CodeLlama-7b-Python-hf": dict(
-        block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
+        block_size=16384, vocab_size=32000, n_layers=32, dim=4096, rope_base=1000000
     ),
-    "7B": dict(n_layer=32, n_heads=32, dim=4096),
-    "13B": dict(n_layer=40, n_heads=40, dim=5120),
-    "30B": dict(n_layer=60, n_heads=52, dim=6656),
+    "7B": dict(n_layers=32, n_heads=32, dim=4096),
+    "13B": dict(n_layers=40, n_heads=40, dim=5120),
+    "30B": dict(n_layers=60, n_heads=52, dim=6656),
     "34B": dict(
-        n_layer=48,
+        n_layers=48,
         n_heads=64,
         dim=8192,
         vocab_size=32000,
         n_local_heads=8,
         hidden_dim=22016,
         rope_base=1000000,
     ), # CodeLlama-34B-Python-hf
-    "70B": dict(n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672),
+    "70B": dict(n_layers=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672),
     "Mistral-7B": dict(
-        n_layer=32,
+        n_layers=32,
         n_heads=32,
         n_local_heads=8,
         dim=4096,
         hidden_dim=14336,
         vocab_size=32000,
     ),
     "Mistral-7B-Instruct-v0.1": dict(
-        n_layer=32,
+        n_layers=32,
         n_heads=32,
         n_local_heads=8,
         dim=4096,
         hidden_dim=14336,
         vocab_size=32000,
     ),
     "Mistral-7B-Instruct-v0.2": dict(
-        n_layer=32,
+        n_layers=32,
         n_heads=32,
         n_local_heads=8,
         dim=4096,
         hidden_dim=14336,
         vocab_size=32000,
     ),
-    "stories15M": dict(n_layer=6, n_heads=6, dim=288),
-    "stories110M": dict(n_layer=12, n_heads=12, dim=768),
+    "stories15M": dict(n_layers=6, n_heads=6, dim=288),
+    "stories110M": dict(n_layers=12, n_heads=12, dim=768),
 }
 
 
@@ -169,7 +169,7 @@ def __init__(self, config: ModelArgs) -> None:
 
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         self.layers = nn.ModuleList(
-            TransformerBlock(config) for _ in range(config.n_layer)
+            TransformerBlock(config) for _ in range(config.n_layers)
        )
         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
```
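The last hunk is where the renamed field is consumed: the model stacks one `TransformerBlock` per `config.n_layers`. A self-contained sketch of that pattern with toy classes (not the repo's `TransformerBlock` or `Transformer`):

```python
import torch.nn as nn


class ToyBlock(nn.Module):  # stand-in for TransformerBlock
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class ToyTransformer(nn.Module):
    def __init__(self, n_layers: int, dim: int):
        super().__init__()
        # One block per layer, mirroring the diff's use of config.n_layers.
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Matches the "stories15M" entry in transformer_configs: 6 layers, dim 288.
print(len(ToyTransformer(n_layers=6, dim=288).layers))  # 6
```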
