Split sharded Llama dataset exporting and loading in export scripts #327
Conversation
- llama_config.tensor_parallelism_size = attn_q_weight.shard_count
+ llama_config = LlamaModelConfig(
+     hp,
+     tensor_parallelism_size=args.tensor_parallelism_size,
This will not work, as it assumes the passed argument is guaranteed to match the irpa file. We should look into plumbing the sharding into the saved hyper parameters and then extracting it from there. I also dislike the hack in line 84; however, the solution is not to introduce mismatches.
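A minimal sketch of what that plumbing could look like on the export side, assuming sharktank's Dataset.load/Dataset.save API and its free-form properties dict; the helper name and argument wiring below are illustrative, not the actual script:

```python
from sharktank.types import Dataset

def export_with_sharding_metadata(
    input_irpa: str, output_irpa: str, tensor_parallelism_size: int
):
    # Record the sharding in the saved hyper parameters so the loading side
    # can recover it from the irpa file instead of trusting a CLI flag to match.
    dataset = Dataset.load(input_irpa)
    dataset.properties["tensor_parallelism_size"] = tensor_parallelism_size
    dataset.save(output_irpa)
```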
OK, I will add it to the exported irpa hyper parameters.
I fixed it.
- llama_config.tensor_parallelism_size = attn_q_weight.shard_count
+ llama_config = LlamaModelConfig(
+     hp,
+     tensor_parallelism_size=dataset.properties["tensor_parallelism_size"],
Add a check for when tensor_parallelism_size has no value and default to 1 in those cases. We should make sure that old non-sharded irpa files still work.
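A hedged sketch of that check, assuming dataset.properties behaves like a plain dict: old non-sharded irpa files simply never recorded the key, so a .get with a default keeps them working.

```python
# Default to 1 when the property is absent (old non-sharded irpa files).
tensor_parallelism_size = dataset.properties.get("tensor_parallelism_size", 1)
```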
- llama_config.tensor_parallelism_size = attn_q_weight.shard_count
+ llama_config = LlamaModelConfig(
+     hp,
+     tensor_parallelism_size=dataset.properties["tensor_parallelism_size"]
Just declare it outside of the llama_config construction. Giant comprehensions are bad for readability.
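For example, resolving the value into a named variable before the constructor call (a sketch; the variable name is illustrative) keeps the call site readable and folds in the default-to-1 behaviour requested above:

```python
# Old non-sharded irpa files carry no "tensor_parallelism_size" property.
tensor_parallelism_size = dataset.properties.get("tensor_parallelism_size", 1)
llama_config = LlamaModelConfig(
    hp,
    tensor_parallelism_size=tensor_parallelism_size,
)
```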
Done
Done
Force-pushed from 2c22d38 to 38c3410:
Separate the 2 steps. We need exported irpa files for the IREE module anyway.
Force-pushed from 3fbb144 to 1dc7897:
Separate the 2 steps. We need exported irpa files for the IREE module anyway.