
Split sharded Llama dataset exporting and loading in export scripts #327

Merged
sogartar merged 4 commits into nod-ai:main from sharded-llama-dataset-exporting on Oct 25, 2024

Conversation

sogartar
Contributor

Separate the two steps. We need the exported IRPA files for the IREE module anyway.

llama_config.tensor_parallelism_size = attn_q_weight.shard_count
llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=args.tensor_parallelism_size,
Contributor

This will not work: it assumes the passed argument is guaranteed to match the IRPA file. We should look into plumbing the sharding into the saved hyperparameters and then extracting it from there. I also dislike the hack on line 84, but the solution is not to introduce mismatches.
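Roughly something like this (a sketch only; it assumes the dataset object used by these scripts exposes a properties dict, as in the snippet below, plus save/load helpers, and the path variables are placeholders):

from sharktank.types import Dataset  # assumed location of the Dataset type

# Export side: record the sharding in the dataset properties so the
# loading script does not have to trust a CLI flag.
dataset.properties["tensor_parallelism_size"] = args.tensor_parallelism_size
dataset.save(output_irpa_path)  # output_irpa_path is a placeholder

# Load side: read it back out of the IRPA file.
dataset = Dataset.load(irpa_path)  # irpa_path is a placeholder
tensor_parallelism_size = dataset.properties["tensor_parallelism_size"]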

Contributor Author

OK, I will add it to the exported IRPA hyperparameters.

Contributor Author

I fixed it.

llama_config.tensor_parallelism_size = attn_q_weight.shard_count
llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=dataset.properties["tensor_parallelism_size"],
Contributor

Add a check before this for the case where tensor_parallelism_size has no value, and default to 1 there. We should make sure the old non-sharded IRPA files still work.

llama_config.tensor_parallelism_size = attn_q_weight.shard_count
llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=dataset.properties["tensor_parallelism_size"]
Contributor

Just declare it outside of llama_config. Giant comprehensions are bad for readability.
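E.g., combining this with the default-to-1 suggestion above (a sketch; the remaining constructor arguments stay as in the existing call):

# Default to 1 so old non-sharded IRPA files keep working.
tensor_parallelism_size = int(dataset.properties.get("tensor_parallelism_size", 1))
llama_config = LlamaModelConfig(
    hp,
    tensor_parallelism_size=tensor_parallelism_size,
    # ... other arguments unchanged ...
)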

Contributor Author

Done

Contributor Author

Done

sogartar force-pushed the sharded-llama-dataset-exporting branch from 2c22d38 to 38c3410 on October 25, 2024 at 07:48
sogartar force-pushed the sharded-llama-dataset-exporting branch from 3fbb144 to 1dc7897 on October 25, 2024 at 19:10
sogartar enabled auto-merge (squash) on October 25, 2024 at 19:11
sogartar merged commit 1aeb3a8 into nod-ai:main on Oct 25, 2024
3 checks passed