From 357cef41b324aa6108d28b4ba6e590bdfa501282 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 4 Dec 2024 16:07:54 +0100 Subject: [PATCH] add unit test --- test/test_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 test/test_utils.py diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..e6842c1 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,17 @@ +import pytest +import torch + +from chronos.utils import left_pad_and_stack_1D + + +@pytest.mark.parametrize("tensors", [ + list(map(torch.tensor, [[1, 2, 3], [5, 6]])), + list(map(torch.tensor, [[2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]])), +]) +def test_pad_and_stack(tensors: list): + stacked_and_padded = left_pad_and_stack_1D(tensors) + assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors)) + + ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype) + + assert torch.nanmean(stacked_and_padded) == torch.nanmean(ref)