From 5f5a6676e6395a1e1a76913c8715a5cea9a79346 Mon Sep 17 00:00:00 2001
From: Sander Land <48946947+sanderland@users.noreply.github.com>
Date: Fri, 23 Aug 2024 11:15:43 +0200
Subject: [PATCH 1/4] Preserve eos in encoding when max_seq_length = -1

---
 litgpt/data/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/litgpt/data/base.py b/litgpt/data/base.py
index f4ef68a818..cca6e8cb93 100644
--- a/litgpt/data/base.py
+++ b/litgpt/data/base.py
@@ -81,7 +81,9 @@ def __getitem__(self, idx: int) -> Dict[str, Tensor]:
         prompt = self.prompt_style.apply(prompt=example["instruction"], **example)
         encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length)
         encoded_response = self.tokenizer.encode(example["output"], bos=False, eos=True, max_length=self.max_seq_length)
-        encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)[: self.max_seq_length]
+        encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)
+        if self.max_seq_length > 0:  # do not slice off last token when self.max_seq_length = -1
+            encoded_prompt_and_response = encoded_prompt_and_response[: self.max_seq_length]
 
         # The labels are the full prompt with response, but with the prompt masked out
         labels = encoded_prompt_and_response.clone()

From 796cea31430c10aee1efb99f31aedad335b3242d Mon Sep 17 00:00:00 2001
From: Sander Land
Date: Sat, 24 Aug 2024 20:57:20 +0200
Subject: [PATCH 2/4] patch

---
 test.py | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 test.py

diff --git a/test.py b/test.py
new file mode 100644
index 0000000000..dde64795c3
--- /dev/null
+++ b/test.py
@@ -0,0 +1,5 @@
+from litgpt import LLM
+
+llm = LLM.load("microsoft/phi-2")
+text = llm.generate("Fix the spelling: Every fall, the familly goes to the mountains.")
+print(text)

From d2e164c615ec803fa5cdfeead9e638a5a3bd8ef2 Mon Sep 17 00:00:00 2001
From: Sander Land
Date: Sat, 24 Aug 2024 20:57:44 +0200
Subject: [PATCH 3/4] edit tests/data/test_base.py

---
 tests/data/test_base.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/tests/data/test_base.py b/tests/data/test_base.py
index e948427dec..394200c6d1 100644
--- a/tests/data/test_base.py
+++ b/tests/data/test_base.py
@@ -9,7 +9,7 @@
 
 @pytest.mark.parametrize("mask_prompt", [True, False])
 @pytest.mark.parametrize("ignore_index", [-1, -100])
-@pytest.mark.parametrize("max_seq_length", [1000, 5])
+@pytest.mark.parametrize("max_seq_length", [1000, 5, -1])
 def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer):
     class Style(PromptStyle):
         def apply(self, prompt, **kwargs):
@@ -34,8 +34,12 @@ def apply(self, prompt, **kwargs):
         torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) if mask_prompt else expected_input_ids
     )
 
-    assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
-    assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])
+    if max_seq_length == -1:
+        assert torch.equal(dataset[0]["input_ids"], expected_input_ids)
+        assert torch.equal(dataset[0]["labels"], expected_labels)
+    else:
+        assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
+        assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])
 
 
 @pytest.mark.parametrize("ignore_index", [-1, -100])

From 6124ff35d12a1410913cb29cc8b30d945b7da39f Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Sun, 25 Aug 2024 04:57:33 -0400
Subject: [PATCH 4/4] Delete test.py

---
 test.py | 5 -----
 1 file changed, 5 deletions(-)
 delete mode 100644 test.py

diff --git a/test.py b/test.py
deleted file mode 100644
index dde64795c3..0000000000
--- a/test.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from litgpt import LLM
-
-llm = LLM.load("microsoft/phi-2")
-text = llm.generate("Fix the spelling: Every fall, the familly goes to the mountains.")
-print(text)
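
Context for PATCH 1/4 (a minimal sketch, not part of the patch series; the token ids below are hypothetical): with max_seq_length = -1, the old unconditional slice `[: self.max_seq_length]` becomes `[:-1]` and silently drops the final token, which is exactly the EOS appended by encode(..., eos=True). The guarded slice only truncates for a positive limit.

import torch

# Hypothetical token ids ending in an assumed EOS id of 2.
ids = torch.tensor([1, 66, 97, 114, 2], dtype=torch.int64)

print(ids[:-1])    # tensor([  1,  66,  97, 114])        -> EOS stripped (old behavior with -1)
print(ids[:1000])  # tensor([  1,  66,  97, 114,   2])   -> harmless when the limit exceeds the length

# The patched code slices only when the limit is positive:
max_seq_length = -1
if max_seq_length > 0:
    ids = ids[:max_seq_length]
print(ids)         # tensor([  1,  66,  97, 114,   2])   -> EOS preserved (new behavior)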