diff --git a/litgpt/data/base.py b/litgpt/data/base.py
index f4ef68a818..cca6e8cb93 100644
--- a/litgpt/data/base.py
+++ b/litgpt/data/base.py
@@ -81,7 +81,9 @@ def __getitem__(self, idx: int) -> Dict[str, Tensor]:
         prompt = self.prompt_style.apply(prompt=example["instruction"], **example)
         encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length)
         encoded_response = self.tokenizer.encode(example["output"], bos=False, eos=True, max_length=self.max_seq_length)
-        encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)[: self.max_seq_length]
+        encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)
+        if self.max_seq_length > 0:  # do not slice off last token when self.max_seq_length = -1
+            encoded_prompt_and_response = encoded_prompt_and_response[: self.max_seq_length]
 
         # The labels are the full prompt with response, but with the prompt masked out
         labels = encoded_prompt_and_response.clone()
diff --git a/tests/data/test_base.py b/tests/data/test_base.py
index e948427dec..394200c6d1 100644
--- a/tests/data/test_base.py
+++ b/tests/data/test_base.py
@@ -9,7 +9,7 @@
 
 @pytest.mark.parametrize("mask_prompt", [True, False])
 @pytest.mark.parametrize("ignore_index", [-1, -100])
-@pytest.mark.parametrize("max_seq_length", [1000, 5])
+@pytest.mark.parametrize("max_seq_length", [1000, 5, -1])
 def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer):
     class Style(PromptStyle):
         def apply(self, prompt, **kwargs):
@@ -34,8 +34,12 @@ def apply(self, prompt, **kwargs):
         torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) if mask_prompt else expected_input_ids
     )
 
-    assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
-    assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])
+    if max_seq_length == -1:
+        assert torch.equal(dataset[0]["input_ids"], expected_input_ids)
+        assert torch.equal(dataset[0]["labels"], expected_labels)
+    else:
+        assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
+        assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])
 
 
 @pytest.mark.parametrize("ignore_index", [-1, -100])