Add missing --ffn-expansion-factor to FLOPs calculator script (#35)
* add ffn-expansion-factor to flops script

* Update calc_transformer_flops.py

---------

Co-authored-by: Quentin Anthony <[email protected]>
haileyschoelkopf and Quentin-Anthony authored Aug 11, 2024
1 parent a17c66f commit ee051ec
Showing 1 changed file with 6 additions and 2 deletions.
calc/calc_transformer_flops.py
@@ -36,6 +36,10 @@ def config_parser():
                         type=float,
                         default=1.0,
                         help='Ratio of kv heads to query heads used in model. 1.0 for MHA')
+    parser.add_argument("--ffn-expansion-factor", "-ff",
+                        type=int,
+                        default=4,
+                        help='How much the MLP hidden size expands')
     parser.add_argument("--moe",
                         action="store_true",
                         help='Whether our model is MoE')
@@ -102,9 +106,9 @@ def calc_flops(args):
     attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
     attention_over_values_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
     linear_projection_flops = iter_factor * 2 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
-    ffn_flops = iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
+    ffn_flops = int(iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
     if args.swiglu:
-        ffn_flops = 3/2 * ffn_flops
+        ffn_flops = int(3/2 * ffn_flops)
     # no activation checkpointing for embeddings
     embedding_flops = 6 * args.tokens * args.hidden_size * args.vocab_size

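For readers skimming the diff, the sketch below restates the updated FFN term as standalone Python. It mirrors the new ffn_flops lines above; the helper name, its default arguments (notably num_mlp_linears and iter_factor), and the worked numbers are illustrative assumptions, not values taken from the script.

    # Standalone sketch of the FFN FLOPs term after this change (illustrative only).
    def ffn_flops(num_layers, tokens, hidden_size,
                  num_mlp_linears=2, ffn_expansion_factor=4,
                  swiglu=False, iter_factor=3):
        # iter_factor bundles the forward/backward multiplier (commonly 3, or 4
        # with activation recomputation); the real script derives it from its flags.
        flops = (int(iter_factor * 2 * num_mlp_linears * ffn_expansion_factor)
                 * num_layers * tokens * hidden_size * hidden_size)
        if swiglu:
            # SwiGLU adds a third matmul, so the two-linear FFN cost grows by 3/2.
            flops = int(3 / 2 * flops)
        return flops

    # Example: 32 layers, 1B tokens, hidden size 4096, default 4x expansion.
    print(f"{ffn_flops(32, 1_000_000_000, 4096):.3e}")  # ≈2.577e+19 FLOPs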

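With the flag in place, the MLP expansion factor can be overridden from the command line rather than being hard-coded. The invocation below is a sketch: --ffn-expansion-factor, --swiglu, and --moe appear in the diff, but the remaining flag spellings (--hidden-size, --num-layers, --sequence-length, --tokens) are assumed to follow the script's kebab-case convention and are not confirmed by this hunk.

    python calc/calc_transformer_flops.py \
        --hidden-size 4096 \
        --num-layers 32 \
        --sequence-length 4096 \
        --tokens 1000000000 \
        --swiglu \
        --ffn-expansion-factor 4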