-
Notifications
You must be signed in to change notification settings - Fork 308
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pipeline.embed support for Chronos-Bolt (#247)
- Loading branch information
1 parent
28e7b32
commit ad410c9
Showing
3 changed files
with
106 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles( | |
validate_tensor(mean, (1, prediction_length), dtype=torch.float32) | ||
|
||
|
||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) | ||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) | ||
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): | ||
pipeline = ChronosBoltPipeline.from_pretrained( | ||
Path(__file__).parent / "dummy-chronos-bolt-model", | ||
device_map="cpu", | ||
torch_dtype=model_dtype, | ||
) | ||
d_model = pipeline.model.config.d_model | ||
context = 10 * torch.rand(size=(4, 16)) + 10 | ||
context = context.to(dtype=input_dtype) | ||
|
||
# the patch size of dummy model is 16, so only 1 patch is created | ||
expected_embed_length = 1 + ( | ||
1 if pipeline.model.config.chronos_config["use_reg_token"] else 0 | ||
) | ||
|
||
# input: tensor of shape (batch_size, context_length) | ||
|
||
embedding, loc_scale = pipeline.embed(context) | ||
validate_tensor( | ||
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype | ||
) | ||
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32) | ||
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32) | ||
|
||
# input: batch_size-long list of tensors of shape (context_length,) | ||
|
||
embedding, loc_scale = pipeline.embed(list(context)) | ||
validate_tensor( | ||
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype | ||
) | ||
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32) | ||
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32) | ||
|
||
# input: tensor of shape (context_length,) | ||
embedding, loc_scale = pipeline.embed(context[0, ...]) | ||
validate_tensor( | ||
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype | ||
) | ||
validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32) | ||
validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32) | ||
|
||
|
||
# The following tests have been taken from | ||
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py | ||
# Author: Caner Turkmen <[email protected]> | ||
|