From 08f16d0cec563e559a66867ffc740d1f829815b6 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Sun, 15 Sep 2024 23:55:42 -0700
Subject: [PATCH] Update phi-3-mini lora export code and readme (#5327)

Summary:
Updated readme and example export code ahead of branch cut.

Pull Request resolved: https://github.com/pytorch/executorch/pull/5327

Test Plan:
- Exported manually

Reviewed By: JacobSzwejbka

Differential Revision: D62623250

Pulled By: dvorjackz

fbshipit-source-id: 79ee3ad1d42ae961d94d225ee1e642c5bc540127
---
 examples/models/phi-3-mini-lora/README.md    |  4 +++-
 .../models/phi-3-mini-lora/export_model.py   | 23 ++++++++++++++-----
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/examples/models/phi-3-mini-lora/README.md b/examples/models/phi-3-mini-lora/README.md
index 69564581af..92f23f137b 100644
--- a/examples/models/phi-3-mini-lora/README.md
+++ b/examples/models/phi-3-mini-lora/README.md
@@ -1,5 +1,7 @@
 ## Summary
-In this example, we export to ExecuTorch a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with attention and mlp LoRA layers. The model is exported to ExecuTorch for both inference and training. Note: the exported training model can only train at the moment.
+In this example, we showcase how to export a model ([phi-3-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi-3-mini)) appended with LoRA layers to ExecuTorch. The model is exported to ExecuTorch for both inference and training.
+
+To see how you can use the model exported for training in a full finetuning loop, please see our example on [LLM PTE Finetuning](https://github.com/pytorch/executorch/tree/main/examples/llm_pte_finetuning).
 
 ## Instructions
 ### Step 1: [Optional] Install ExecuTorch dependencies
diff --git a/examples/models/phi-3-mini-lora/export_model.py b/examples/models/phi-3-mini-lora/export_model.py
index eb8fbc07fe..e6f291bd58 100644
--- a/examples/models/phi-3-mini-lora/export_model.py
+++ b/examples/models/phi-3-mini-lora/export_model.py
@@ -28,11 +28,13 @@ def __init__(self, model, loss):
         self.model = model
         self.loss = loss
 
-    def forward(self, input):
+    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Output is of the shape (seq_len, vocab_size).
-        output = self.model(input)
-        target = zeros((1, vocab_size), dtype=long)
-        return self.loss(output, target)
+        logits = self.model(input)
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+        logits = logits.transpose(1, 2)
+        return self.loss(logits, labels)
 
 
 @no_grad()
@@ -47,7 +49,11 @@ def export_phi3_mini_lora(model) -> None:
     model.eval()
     # 1. torch.export: Defines the program with the ATen operator set.
     print("Exporting to aten dialect")
-    example_args = (randint(0, 100, (1, 100), dtype=long),)
+    batch_size = 1
+    vocab_size = 100
+    seq_len = 10
+    tokens = randint(0, vocab_size, (batch_size, seq_len), dtype=long)
+    example_args = (tokens,)
     with sdpa_kernel([SDPBackend.MATH]):
         aten_dialect: ExportedProgram = export(model, example_args)
 
@@ -80,7 +86,12 @@ def export_phi3_mini_lora_training(model) -> None:
     print("Exporting phi3-mini with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.
     print("Exporting to aten dialect")
-    example_args = (randint(0, 100, (1, 100), dtype=long),)
+    batch_size = 1
+    vocab_size = 100
+    seq_len = 10
+    tokens = randint(0, vocab_size, (batch_size, seq_len), dtype=long)
+    labels = tokens
+    example_args = (tokens, labels)
     with sdpa_kernel([SDPBackend.MATH]):
         exported_graph: ExportedProgram = export(model, example_args)
         print("Creating a joint forward-backwards graph for training")