Skip to content

Commit

Permalink
Add pipeline.embed support for Chronos-Bolt (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdulfatir authored Dec 22, 2024
1 parent 28e7b32 commit ad410c9
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 17 deletions.
11 changes: 3 additions & 8 deletions .github/workflows/eval-model.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ on:
- labeled # When a label is added to the PR

jobs:
evaluate-and-post:
evaluate-and-print:
if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
runs-on: ubuntu-latest
env:
Expand All @@ -33,10 +33,5 @@ jobs:
- name: Run Eval Script
run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32

- name: Upload CSV
uses: actions/upload-artifact@v4
with:
name: eval-metrics
path: ${{ env.RESULTS_CSV }}
retention-days: 1
overwrite: true
- name: Print CSV
run: cat $RESULTS_CSV
68 changes: 59 additions & 9 deletions src/chronos/chronos_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from .base import BaseChronosPipeline, ForecastType


logger = logging.getLogger(__file__)


Expand Down Expand Up @@ -240,13 +239,11 @@ def _init_weights(self, module):
):
module.output_layer.bias.data.zero_()

def forward(
self,
context: torch.Tensor,
mask: Optional[torch.Tensor] = None,
target: Optional[torch.Tensor] = None,
target_mask: Optional[torch.Tensor] = None,
) -> ChronosBoltOutput:
def encode(
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> Tuple[
torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
]:
mask = (
mask.to(context.dtype)
if mask is not None
Expand Down Expand Up @@ -301,8 +298,21 @@ def forward(
attention_mask=attention_mask,
inputs_embeds=input_embeds,
)
hidden_states = encoder_outputs[0]

return encoder_outputs[0], loc_scale, input_embeds, attention_mask

def forward(
self,
context: torch.Tensor,
mask: Optional[torch.Tensor] = None,
target: Optional[torch.Tensor] = None,
target_mask: Optional[torch.Tensor] = None,
) -> ChronosBoltOutput:
batch_size = context.size(0)

hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
context=context, mask=mask
)
sequence_output = self.decode(input_embeds, attention_mask, hidden_states)

quantile_preds_shape = (
Expand Down Expand Up @@ -426,6 +436,46 @@ def __init__(self, model: ChronosBoltModelForForecasting):
def quantiles(self) -> List[float]:
return self.model.config.chronos_config["quantiles"]

@torch.no_grad()
def embed(
self, context: Union[torch.Tensor, List[torch.Tensor]]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Get encoder embeddings for the given time series.
Parameters
----------
context
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
Returns
-------
embeddings, loc_scale
A tuple of two items: the encoder embeddings and the loc_scale,
i.e., the mean and std of the original time series.
The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
where num_patches is the number of patches in the time series
and the extra 1 is for the [REG] token (if used by the model).
"""
context_tensor = self._prepare_and_validate_context(context=context)
model_context_length = self.model.config.chronos_config["context_length"]

if context_tensor.shape[-1] > model_context_length:
context_tensor = context_tensor[..., -model_context_length:]

context_tensor = context_tensor.to(
device=self.model.device,
dtype=torch.float32,
)
embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
return embeddings.cpu(), (
loc_scale[0].squeeze(-1).cpu(),
loc_scale[1].squeeze(-1).cpu(),
)

def predict( # type: ignore[override]
self,
context: Union[torch.Tensor, List[torch.Tensor]],
Expand Down
44 changes: 44 additions & 0 deletions test/test_chronos_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles(
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)


@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosBoltPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-bolt-model",
device_map="cpu",
torch_dtype=model_dtype,
)
d_model = pipeline.model.config.d_model
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)

# the patch size of dummy model is 16, so only 1 patch is created
expected_embed_length = 1 + (
1 if pipeline.model.config.chronos_config["use_reg_token"] else 0
)

# input: tensor of shape (batch_size, context_length)

embedding, loc_scale = pipeline.embed(context)
validate_tensor(
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)

# input: batch_size-long list of tensors of shape (context_length,)

embedding, loc_scale = pipeline.embed(list(context))
validate_tensor(
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)

# input: tensor of shape (context_length,)
embedding, loc_scale = pipeline.embed(context[0, ...])
validate_tensor(
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32)
validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32)


# The following tests have been taken from
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
# Author: Caner Turkmen <[email protected]>
Expand Down

0 comments on commit ad410c9

Please sign in to comment.