diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a31944a3ef69..537a3cf198f0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1407,6 +1407,8 @@ - sections: - local: model_doc/autoformer title: Autoformer + - local: model_doc/ctsm + title: CTSM - local: model_doc/informer title: Informer - local: model_doc/patchtsmixer diff --git a/docs/source/en/model_doc/ctsm.md b/docs/source/en/model_doc/ctsm.md new file mode 100644 index 000000000000..8d891a07f633 --- /dev/null +++ b/docs/source/en/model_doc/ctsm.md @@ -0,0 +1,122 @@ + +*This model was released on 2025-11-25 and added to Hugging Face Transformers on 2026-04-17.* + +
+
+ PyTorch + FlashAttention + SDPA +
+
+ +# CTSM + +## Overview + +The Cisco Time Series Model (CTSM) was proposed in [Cisco Time Series Model Technical Report](https://huggingface.co/papers/2511.19841) by Liang Gou, Archit Khare, Praneet Pabolu, Prachi Patel, Joseph Ross, Hercy Shen, Yuhan (Ellen) Song, Jingze Sun, Kristal Curtis, Vedant Dharnidharka, Abhinav Mathur and Hao Yang. + +CTSM is a decoder-only univariate zero-shot forecasting foundation model. Its central idea is a **multi-resolution context**: instead of consuming a single-scale history, each forecast conditions on two aligned streams — a coarse low-frequency stream (e.g. 512 hourly points) and a fine high-frequency stream (e.g. 512 minutely points), with the resolution ratio fixed to 60. A learnable **special token** separates the two streams and learned **resolution embeddings** are added to the token stream to distinguish them. The coarse stream lets the model see week-over-week structure without giving up fine-grained recent detail; as the paper puts it, "more complex multiresolution architectures would require a context length of 30,720 (30 times as long as ours) to cover the same time range." + +The abstract from the paper is the following: + +*We introduce the Cisco Time Series Model, a univariate zero-shot forecaster. This time series foundation model is the result of a general architectural innovation to a time series model enabling it to accept multiresolution input, applied to a popular decoder-only time series model (TimesFM). The resulting multiresolution decoder-only model is trained on over 300B unique data points, with more than half coming from the observability domain. 
Quantitative and qualitative evaluations demonstrate that the resulting model achieves superior performance on observability datasets while retaining very similar performance on a standard general-purpose forecasting benchmark (GIFT-Eval), and suggest that the multiresolution structure enables the model to make more accurate predictions on long context input.* + +### Architecture + +The backbone follows TimesFM 2.0: patching (patch length 32) + a residual-block input tokenizer + decoder-only transformer layers with per-dimension learnable query scaling + a residual-block horizon head. CTSM adds, on top: + +- A **special token** inserted between the coarse and fine patch streams, so the input is `[coarse₁, …, coarse₁₆, SPECIAL, fine₁, …, fine₁₆]`. +- **Resolution embeddings** (3-way: coarse / special / fine) added to each token before the transformer stack. +- **Stream-level normalization**: each stream is standardized independently over its non-padded context, and the fine-stream statistics are used to rescale the forecast. +- A **frequency embedding** inherited from TimesFM, added to every token. + +The 250M **CTSM 1.0** release checkpoint additionally introduces (over the 500M `1.0-preview` described in the paper): + +- **Rotary position embeddings (RoPE)** applied to query/key inside attention. +- **Bidirectional attention over the coarse block** — tokens in the coarse segment attend both ways within that segment, while the fine segment remains causal. +- **15-quantile prediction** (levels 0.01–0.99) instead of 9. +- **Short-context training** (1/3 of training samples drawn with `|fine| ∈ [10, 511]`) for better robustness when less history is available. +- Trained from scratch (not continued pre-training from TimesFM 2.0) on ~2× more internal observability data. 
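The multi-resolution context described above can be sketched end-to-end. This is a minimal illustration, not the library's implementation: it assumes minute-level fine data and the defaults `agg_factor=60`, `patch_length=32`, `context_length=512`, and the constant names are only mirrors of those config fields.

```python
AGG_FACTOR = 60       # 60 fine (minutely) points -> 1 coarse (hourly) point
PATCH_LENGTH = 32     # both streams are patched with the same patch length
CONTEXT_LENGTH = 512  # per-stream context length

# A long fine-resolution (e.g. minute-level) history as a plain list of floats.
fine_history = [float(i % 100) for i in range(CONTEXT_LENGTH * AGG_FACTOR)]

# Coarse stream: mean over consecutive blocks of AGG_FACTOR fine points,
# then keep the most recent CONTEXT_LENGTH points of each stream.
coarse = [
    sum(fine_history[i : i + AGG_FACTOR]) / AGG_FACTOR
    for i in range(0, len(fine_history), AGG_FACTOR)
][-CONTEXT_LENGTH:]
fine = fine_history[-CONTEXT_LENGTH:]

num_coarse = len(coarse) // PATCH_LENGTH  # 512 / 32 = 16 coarse patches
num_fine = len(fine) // PATCH_LENGTH      # 16 fine patches
tokens = ["coarse"] * num_coarse + ["SPECIAL"] + ["fine"] * num_fine
print(num_coarse, num_fine, len(tokens))  # 16 16 33
```

The resulting 33-token sequence (16 coarse + 1 special + 16 fine) is what the resolution embeddings distinguish: each position is tagged as coarse, special, or fine before entering the transformer stack, so the model covers 512 hours of history while paying the sequence-length cost of roughly 17 hours.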
+ +### Inference + +For `horizon_len > config.horizon_length`, [`CtsmModelForPrediction`] runs an autoregressive multi-resolution decode loop, using a [`DynamicCache`] by default (opt out with `use_cache=False`). Each step feeds only the newly-appended fine patches through the stack and attends to cached K/V for every earlier position. Stream-normalization statistics are frozen to their step-1 values so that cached K/V remains valid; the coarse block is pinned and the cache is rebuilt if the concatenated sequence would outgrow `max_position_embeddings`. + +The checkpoint can be found at [`cisco-ai/cisco-time-series-model-1.0`](https://huggingface.co/cisco-ai/cisco-time-series-model-1.0). The original inference code is at [github.com/splunk/cisco-time-series-model](https://github.com/splunk/cisco-time-series-model). + +This model was contributed by [kashif](https://huggingface.co/kashif). + +## Usage + +Pass a list of fine-resolution time series (e.g. minute-level); the coarse stream is built automatically by mean-aggregating consecutive blocks of `config.agg_factor` points. + +```python +import numpy as np +import torch +from transformers import CtsmModelForPrediction + + +model = CtsmModelForPrediction.from_pretrained("cisco-ai/cisco-time-series-model-1.0", device_map="auto") + +# ~8.5 hours of 1-minute data; the model will build a 512-hour coarse context by aggregation. +series = np.sin(np.linspace(0, 200, 512 * 60)).astype(np.float32) +past_values = [torch.tensor(series, device=model.device)] + +with torch.no_grad(): + outputs = model(past_values=past_values, horizon_len=128) + +point_forecast = outputs.mean_predictions # (batch, horizon_len) +quantile_forecast = outputs.full_predictions # (batch, horizon_len, 1 + num_quantiles) +``` + +If you already have a coarse stream (e.g. 
pre-computed 1-hour roll-ups that go further back than you have 1-minute data for), pass `(coarse, fine)` pairs directly: + +```python +coarse = torch.tensor(hourly_series, dtype=torch.float32) # up to 512 points +fine = torch.tensor(minutely_series, dtype=torch.float32) # up to 512 points +outputs = model(past_values=[(coarse, fine)], horizon_len=128) +``` + +For `horizon_len > 128`, the model decodes autoregressively and extends the output accordingly. + +## CtsmConfig + +[[autodoc]] CtsmConfig + +## CtsmModel + +[[autodoc]] CtsmModel + - forward + +## CtsmModelForPrediction + +[[autodoc]] CtsmModelForPrediction + - forward + +## Citation + +```bibtex +@misc{gou2025ciscotimeseriesmodel, + title={Cisco Time Series Model Technical Report}, + author={Liang Gou and Archit Khare and Praneet Pabolu and Prachi Patel and Joseph Ross and Hercy Shen and Yuhan Song and Jingze Sun and Kristal Curtis and Vedant Dharnidharka and Abhinav Mathur and Hao Yang}, + year={2025}, + eprint={2511.19841}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2511.19841} +} +``` diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index acc5e2fdeac0..89e35fbd4e8c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -81,6 +81,7 @@ from .cpmant import * from .csm import * from .ctrl import * + from .ctsm import * from .cvt import * from .cwm import * from .d_fine import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 0ec3cdf700ec..89d4a8599041 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -108,6 +108,7 @@ ("csm", "CsmConfig"), ("csm_depth_decoder_model", "CsmDepthDecoderConfig"), ("ctrl", "CTRLConfig"), + ("ctsm", "CtsmConfig"), ("cvt", "CvtConfig"), ("cwm", "CwmConfig"), ("d_fine", "DFineConfig"), diff --git a/src/transformers/models/auto/modeling_auto.py 
b/src/transformers/models/auto/modeling_auto.py index 44d83634d28e..d12f14b9615d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -99,6 +99,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("cpmant", "CpmAntModel"), ("csm", "CsmForConditionalGeneration"), ("ctrl", "CTRLModel"), + ("ctsm", "CtsmModel"), ("cvt", "CvtModel"), ("cwm", "CwmModel"), ("d_fine", "DFineModel"), @@ -1811,6 +1812,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict( [ + ("ctsm", "CtsmModelForPrediction"), ("timesfm", "TimesFmModelForPrediction"), ("timesfm2_5", "TimesFm2_5ModelForPrediction"), ] diff --git a/src/transformers/models/ctsm/__init__.py b/src/transformers/models/ctsm/__init__.py new file mode 100644 index 000000000000..e5979d7ba3db --- /dev/null +++ b/src/transformers/models/ctsm/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ctsm import * + from .modeling_ctsm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/ctsm/configuration_ctsm.py b/src/transformers/models/ctsm/configuration_ctsm.py new file mode 100644 index 000000000000..f2f62a97d0a7 --- /dev/null +++ b/src/transformers/models/ctsm/configuration_ctsm.py @@ -0,0 +1,118 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ctsm/modular_ctsm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ctsm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="cisco-ai/cisco-time-series-model-1.0") +@strict +class CtsmConfig(PreTrainedConfig): + r""" + patch_length (`int`, *optional*, defaults to 32): + Length of one patch in the input sequence for each resolution stream. + context_length (`int`, *optional*, defaults to 512): + Length of the input context for each resolution stream. + horizon_length (`int`, *optional*, defaults to 128): + Length of the prediction horizon produced per autoregressive step. + freq_size (`int`, *optional*, defaults to 3): + Number of frequency embeddings. + tolerance (`float`, *optional*, defaults to 1e-06): + Numerical tolerance used in normalization. + pad_val (`float`, *optional*, defaults to 1123581321.0): + Sentinel value marking padded positions in the input series. + num_hidden_layers (`int`, *optional*, defaults to 25): + Number of decoder layers. + quantiles (`list[float]`, *optional*, defaults to 15 values between 0.01 and 0.99): + Quantile levels predicted by the model. + use_positional_embedding (`bool`, *optional*, defaults to `False`): + CTSM uses rotary position embeddings and does not add sinusoidal positional embeddings. + use_resolution_embeddings (`bool`, *optional*, defaults to `True`): + Whether to add a learned embedding per resolution bucket (coarse / special / fine). + use_special_token (`bool`, *optional*, defaults to `True`): + Whether to insert a learned special token between the coarse and fine streams. + num_resolutions (`int`, *optional*, defaults to 3): + Number of resolution embeddings (coarse, special token, fine). + agg_factor (`int`, *optional*, defaults to 60): + Aggregation factor between fine and coarse resolutions (e.g. 60 minutes -> 1 hour). 
+ max_position_embeddings (`int`, *optional*, defaults to 1025): + Maximum number of patches in the concatenated sequence (coarse + special + fine). + rope_parameters (`dict`, *optional*): + Rotary position embedding parameters. Defaults to `{"rope_type": "default", "rope_theta": 10000.0}`. + + Example: + + ```python + >>> from transformers import CtsmConfig, CtsmModelForPrediction + + >>> configuration = CtsmConfig() + >>> model = CtsmModelForPrediction(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "ctsm" + keys_to_ignore_at_inference = [] + is_encoder_decoder = False + + patch_length: int = 32 + context_length: int = 512 + horizon_length: int = 128 + freq_size: int = 3 + + num_hidden_layers: int = 25 + hidden_size: int = 1280 + intermediate_size: int = 1280 + head_dim: int = 80 + num_attention_heads: int = 16 + tolerance: float = 1e-6 + rms_norm_eps: float = 1e-6 + quantiles: list[float] | tuple[float, ...] = ( + 0.01, + 0.05, + 0.1, + 0.2, + 0.25, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.75, + 0.8, + 0.9, + 0.95, + 0.99, + ) + pad_val: float = 1123581321.0 + attention_dropout: float | int = 0.0 + use_positional_embedding: bool = False + initializer_range: float = 0.02 + use_resolution_embeddings: bool = True + use_special_token: bool = True + num_resolutions: int = 3 + agg_factor: int = 60 + max_position_embeddings: int = 1025 + rope_parameters: RopeParameters | dict | None = None + + +__all__ = ["CtsmConfig"] diff --git a/src/transformers/models/ctsm/convert_ctsm_original_to_hf.py b/src/transformers/models/ctsm/convert_ctsm_original_to_hf.py new file mode 100644 index 000000000000..0b618547b387 --- /dev/null +++ b/src/transformers/models/ctsm/convert_ctsm_original_to_hf.py @@ -0,0 +1,209 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert a Cisco Time Series Model (CTSM) 1.0 checkpoint to the transformers format. + +Sample usage: + +``` +python src/transformers/models/ctsm/convert_ctsm_original_to_hf.py \ + --output_dir /output/path \ + --huggingface_repo_id cisco-ai/cisco-time-series-model-1.0 +``` +""" + +import argparse +import os + +import torch +from huggingface_hub import snapshot_download + +from transformers import CtsmConfig, CtsmModelForPrediction + + +CTSM_CHECKPOINT_FILENAME = "torch_model.pt" + +# CTSM 1.0 public checkpoint ships 15 quantiles spanning [0.01, 0.99]. +CTSM_1_0_QUANTILES = [0.01, 0.05, 0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99] + + +def _layer_mapping(num_layers: int, hidden_size: int) -> dict[str, str | tuple[str, int]]: + """Return a mapping `old_key -> new_key` (or `(new_prefix, split_idx)` for fused QKV).""" + mapping: dict[str, str | tuple[str, int]] = { + # input tokenizer (residual block) + "input_ff_layer.hidden_layer.0.weight": "model.input_ff_layer.input_layer.weight", + "input_ff_layer.hidden_layer.0.bias": "model.input_ff_layer.input_layer.bias", + "input_ff_layer.output_layer.weight": "model.input_ff_layer.output_layer.weight", + "input_ff_layer.output_layer.bias": "model.input_ff_layer.output_layer.bias", + "input_ff_layer.residual_layer.weight": "model.input_ff_layer.residual_layer.weight", + "input_ff_layer.residual_layer.bias": "model.input_ff_layer.residual_layer.bias", + # frequency, resolution and special token embeddings + "freq_emb.weight": "model.freq_emb.weight", + "multi_resolution.weight": 
"model.multi_resolution.weight", + "special_token": "model.special_token", + # horizon head (residual block) + "horizon_ff_layer.hidden_layer.0.weight": "horizon_ff_layer.input_layer.weight", + "horizon_ff_layer.hidden_layer.0.bias": "horizon_ff_layer.input_layer.bias", + "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", + "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", + "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", + "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", + } + + layer_template = { + # fused qkv -> split into q, k, v below + "stacked_transformer.layers.{i}.self_attn.qkv_proj.weight": ("model.layers.{i}.self_attn", "qkv_weight"), + "stacked_transformer.layers.{i}.self_attn.qkv_proj.bias": ("model.layers.{i}.self_attn", "qkv_bias"), + "stacked_transformer.layers.{i}.self_attn.o_proj.weight": "model.layers.{i}.self_attn.o_proj.weight", + "stacked_transformer.layers.{i}.self_attn.o_proj.bias": "model.layers.{i}.self_attn.o_proj.bias", + "stacked_transformer.layers.{i}.self_attn.scaling": "model.layers.{i}.self_attn.scaling", + "stacked_transformer.layers.{i}.mlp.gate_proj.weight": "model.layers.{i}.mlp.gate_proj.weight", + "stacked_transformer.layers.{i}.mlp.gate_proj.bias": "model.layers.{i}.mlp.gate_proj.bias", + "stacked_transformer.layers.{i}.mlp.down_proj.weight": "model.layers.{i}.mlp.down_proj.weight", + "stacked_transformer.layers.{i}.mlp.down_proj.bias": "model.layers.{i}.mlp.down_proj.bias", + "stacked_transformer.layers.{i}.mlp.layer_norm.weight": "model.layers.{i}.mlp.layer_norm.weight", + "stacked_transformer.layers.{i}.mlp.layer_norm.bias": "model.layers.{i}.mlp.layer_norm.bias", + "stacked_transformer.layers.{i}.input_layernorm.weight": "model.layers.{i}.input_layernorm.weight", + } + for i in range(num_layers): + for old, new in layer_template.items(): + mapping[old.format(i=i)] = new.format(i=i) if isinstance(new, 
str) else (new[0].format(i=i), new[1]) + return mapping + + +def convert_state_dict(original_sd: dict[str, torch.Tensor], hidden_size: int) -> dict[str, torch.Tensor]: + """Rewrite the original CTSM state dict into the transformers key layout.""" + num_layers = 0 + for key in original_sd: + if key.startswith("stacked_transformer.layers."): + idx = int(key.split(".")[2]) + num_layers = max(num_layers, idx + 1) + if num_layers == 0: + raise ValueError("No transformer layers found in the original checkpoint.") + + mapping = _layer_mapping(num_layers, hidden_size) + new_sd: dict[str, torch.Tensor] = {} + missing: list[str] = [] + for old_key, target in mapping.items(): + if old_key not in original_sd: + missing.append(old_key) + continue + tensor = original_sd[old_key] + if isinstance(target, tuple): + prefix, kind = target + if kind == "qkv_weight": + q, k, v = tensor.split(hidden_size, dim=0) + new_sd[f"{prefix}.q_proj.weight"] = q.clone() + new_sd[f"{prefix}.k_proj.weight"] = k.clone() + new_sd[f"{prefix}.v_proj.weight"] = v.clone() + elif kind == "qkv_bias": + q, k, v = tensor.split(hidden_size, dim=0) + new_sd[f"{prefix}.q_proj.bias"] = q.clone() + new_sd[f"{prefix}.k_proj.bias"] = k.clone() + new_sd[f"{prefix}.v_proj.bias"] = v.clone() + else: + raise ValueError(f"Unknown fused projection kind: {kind}") + else: + new_sd[target] = tensor.clone() + if missing: + print(f"[warn] {len(missing)} expected key(s) missing from the original checkpoint (first 5): {missing[:5]}") + return new_sd + + +def _infer_config_from_state_dict(original_sd: dict[str, torch.Tensor]) -> CtsmConfig: + """Infer a `CtsmConfig` from an original CTSM 1.0 state dict.""" + num_layers = 1 + max( + (int(k.split(".")[2]) for k in original_sd if k.startswith("stacked_transformer.layers.")), + default=-1, + ) + hidden_size = original_sd["input_ff_layer.output_layer.weight"].shape[0] + qkv_out = original_sd["stacked_transformer.layers.0.self_attn.qkv_proj.weight"].shape[0] + # qkv is [3 * num_heads * 
head_dim, hidden_size] — split evenly. + num_heads = 16 + head_dim = qkv_out // (3 * num_heads) + horizon_out = original_sd["horizon_ff_layer.output_layer.weight"].shape[0] + horizon_length = 128 + num_outputs = horizon_out // horizon_length + quantiles = ( + CTSM_1_0_QUANTILES if num_outputs - 1 == len(CTSM_1_0_QUANTILES) else [0.1 * i for i in range(1, num_outputs)] + ) + + return CtsmConfig( + num_hidden_layers=num_layers, + hidden_size=hidden_size, + intermediate_size=hidden_size, + num_attention_heads=num_heads, + head_dim=head_dim, + patch_length=32, + context_length=512, + horizon_length=horizon_length, + quantiles=quantiles, + use_positional_embedding=False, + use_resolution_embeddings="multi_resolution.weight" in original_sd, + use_special_token="special_token" in original_sd, + agg_factor=60, + max_position_embeddings=1025, + ) + + +def write_model(output_dir: str, huggingface_repo_id: str, safe_serialization: bool = True) -> None: + os.makedirs(output_dir, exist_ok=True) + local_dir = snapshot_download(repo_id=huggingface_repo_id, allow_patterns=[CTSM_CHECKPOINT_FILENAME]) + checkpoint_path = os.path.join(local_dir, CTSM_CHECKPOINT_FILENAME) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"{CTSM_CHECKPOINT_FILENAME} not found in {huggingface_repo_id}") + + print(f"Loading original checkpoint from {checkpoint_path}") + original_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + config = _infer_config_from_state_dict(original_sd) + print( + f"Inferred CtsmConfig: layers={config.num_hidden_layers} hidden={config.hidden_size} " + f"heads={config.num_attention_heads} head_dim={config.head_dim} quantiles={len(config.quantiles)}" + ) + config.save_pretrained(output_dir) + + model = CtsmModelForPrediction(config) + converted_sd = convert_state_dict(original_sd, hidden_size=config.hidden_size) + + incompatible = model.load_state_dict(converted_sd, strict=False) + if incompatible.missing_keys: + print(f"[warn] missing 
keys after load: {incompatible.missing_keys[:10]}") + if incompatible.unexpected_keys: + print(f"[warn] unexpected keys after load: {incompatible.unexpected_keys[:10]}") + + model.save_pretrained(output_dir, safe_serialization=safe_serialization) + print(f"Saved transformers checkpoint to {output_dir}") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", required=True, help="Where to write the converted HF checkpoint.") + parser.add_argument( + "--huggingface_repo_id", + default="cisco-ai/cisco-time-series-model-1.0", + help="Original CTSM repo on the Hub.", + ) + # `type=bool` would parse any non-empty string (including "False") as True; + # BooleanOptionalAction gives proper --safe_serialization / --no-safe_serialization flags. + parser.add_argument("--safe_serialization", action=argparse.BooleanOptionalAction, default=True) + args = parser.parse_args() + + write_model( + output_dir=args.output_dir, + huggingface_repo_id=args.huggingface_repo_id, + safe_serialization=args.safe_serialization, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/ctsm/modeling_ctsm.py new file mode 100644 index 000000000000..e0ae7cbff49e --- /dev/null +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -0,0 +1,1349 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ctsm/modular_ctsm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ctsm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ... import initialization as init +from ...cache_utils import Cache, DynamicCache +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...modeling_outputs import BaseModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_ctsm import CtsmConfig + + +@dataclass +@auto_docstring +class CtsmOutput(BaseModelOutput): + r""" + loc (`torch.Tensor` of shape `(batch_size,)`): + Stream-level mean used to normalize the fine-resolution context, reused to rescale the final forecast. + scale (`torch.Tensor` of shape `(batch_size,)`): + Stream-level standard deviation of the fine-resolution context. + loc_coarse (`torch.Tensor` of shape `(batch_size,)`): + Stream-level mean used to normalize the coarse-resolution context. + scale_coarse (`torch.Tensor` of shape `(batch_size,)`): + Stream-level standard deviation of the coarse-resolution context. + num_coarse_patches (`int`): + Number of patches (including the optional special token) preceding the fine-resolution block. 
+ num_fine_patches (`int`): + Number of patches in the fine-resolution block of the concatenated sequence. + past_key_values (`Cache`, *optional*): + Key/value cache for the concatenated `[coarse, special, fine]` sequence. Populated when the + caller passes `use_cache=True` (and re-used across autoregressive decode steps). Typically only + the long-horizon AR loop in [`CtsmModelForPrediction`] needs this. + """ + + loc: torch.Tensor | None = None + scale: torch.Tensor | None = None + + loc_coarse: torch.Tensor | None = None + scale_coarse: torch.Tensor | None = None + num_coarse_patches: int | None = None + num_fine_patches: int | None = None + past_key_values: Cache | None = None + + +@dataclass +@auto_docstring +class CtsmOutputForPrediction(BaseModelOutput): + r""" + mean_predictions (`torch.Tensor` of shape `(batch_size, horizon_length)`): + Point forecasts over the fine-resolution horizon. + full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`): + Concatenation of the mean prediction and the quantile predictions along the last axis. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + Training loss combining MSE of the mean forecast and quantile loss when fine-resolution targets are supplied. 
+ """ + + mean_predictions: torch.Tensor | None = None + full_predictions: torch.Tensor | None = None + loss: torch.Tensor | float | None = None + + +class CtsmResidualBlock(nn.Module): + """Ctsm residual block.""" + + def __init__(self, input_dims, hidden_dims, output_dims): + super().__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + self.input_layer = nn.Linear(input_dims, hidden_dims) + self.activation = nn.SiLU() + self.output_layer = nn.Linear(hidden_dims, output_dims) + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.input_layer(x) + hidden = self.activation(hidden) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class CtsmRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: CtsmConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: CtsmConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. 
+ device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def simple_eager_attention_forward( + module: nn.Module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float | int = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + 
attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class CtsmAttention(nn.Module): + """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings. + + Supports an optional `past_key_values` cache so that, during long-horizon autoregressive decoding, + each step only needs to compute K/V for the newly-appended fine patches and attends to the + previously-cached K/V for every earlier position. + """ + + def __init__(self, config: CtsmConfig, layer_idx: int): + super().__init__() + self.config = config + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.layer_idx = layer_idx + + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_heads * self.head_dim + self.scaling = nn.Parameter(torch.empty((self.head_dim,))) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _scale_query(self, query: torch.Tensor) -> torch.Tensor: + scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim)) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = 
self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + query_states = self._scale_query(query_states) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CtsmMLP(nn.Module): + """Pax MLP in pytorch.""" + + def __init__(self, config: CtsmConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +@use_kernel_forward_from_hub("RMSNorm") +class CtsmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + CtsmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class CtsmDecoderLayer(nn.Module): + """CTSM transformer block: attention with RoPE followed by TimesFM 2.0 MLP with padding masking.""" + + def __init__(self, config: CtsmConfig, layer_idx: int): + super().__init__() + self.self_attn = CtsmAttention(config, layer_idx=layer_idx) + self.mlp = CtsmMLP(config) + self.input_layernorm = CtsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + paddings: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, paddings=paddings) + return hidden_states + + +class CtsmPositionalEmbedding(nn.Module): + """Generates position embedding for a given 1-d sequence.""" + + def __init__(self, config: CtsmConfig): + super().__init__() + min_timescale = config.min_timescale + max_timescale = config.max_timescale + self.min_timescale, self.max_timescale = min_timescale, max_timescale + self.embedding_dims = config.hidden_size + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / 
max(num_timescales - 1, 1)
+        self.register_buffer(
+            "inv_timescales",
+            min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+        )
+
+    def forward(self, seq_length=None, position=None):
+        """Generates a Tensor of sinusoids with different frequencies.
+
+        Args:
+            seq_length: an optional Python int defining the output sequence length.
+                Ignored if the `position` argument is specified.
+            position: [B, seq_length], optional position for each token in the
+                sequence, only required when the sequence is packed.
+
+        Returns:
+            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+        """
+        if position is None and seq_length is None:
+            raise ValueError("Either position or seq_length must be provided")
+
+        if position is None:
+            # [1, seqlen]
+            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+        elif position.ndim != 2:
+            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
+
+        scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+
+        # Padding to ensure correct embedding dimension
+        signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+        return signal
+
+
+@auto_docstring
+class CtsmPreTrainedModel(PreTrainedModel):
+    config: CtsmConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["CtsmDecoderLayer"]
+    main_input_name = "past_values"
+    input_modalities = ("time",)
+    _supports_sdpa = True
+    _can_record_outputs = {
+        "hidden_states": CtsmDecoderLayer,
+        "attentions": CtsmAttention,
+    }
+    _supports_flash_attn = True
+    _supports_flex_attn = True
+
+    @torch.no_grad()
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, CtsmAttention):
+            # Initialize scaling parameter
+            init.ones_(module.scaling)
+        elif isinstance(module, CtsmPositionalEmbedding):
+            num_timescales = module.embedding_dims // 2
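+            # The inverse timescales re-initialized below form a geometric progression
+            # from 1/min_timescale down to 1/max_timescale. For instance, with
+            # embedding_dims=8, min_timescale=1 and max_timescale=10_000 (illustrative
+            # values, not necessarily the config defaults), the four buffered entries
+            # are 10_000 ** (-i / 3) for i in 0..3, i.e. roughly
+            # [1.0, 0.0464, 0.00215, 0.0001].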
+            max_timescale, min_timescale = module.max_timescale, module.min_timescale
+            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
+                num_timescales - 1, 1
+            )
+            init.copy_(
+                module.inv_timescales,
+                min_timescale
+                * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+            )
+        if isinstance(module, CtsmModel) and getattr(module, "special_token", None) is not None:
+            init.normal_(module.special_token, mean=0.0, std=self.config.initializer_range)
+
+
+@auto_docstring
+class CtsmModel(CtsmPreTrainedModel):
+    r"""
+    The multi-resolution CTSM encoder. The forward pass consumes two aligned streams (a coarse low-frequency
+    context and a fine high-frequency context), concatenates them along the sequence dimension with an
+    optional learned special token, and runs a stack of rotary-attention transformer layers. Attention is
+    bidirectional within the coarse block and causal elsewhere.
+    """
+
+    def __init__(self, config: CtsmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.input_ff_layer = CtsmResidualBlock(
+            input_dims=2 * config.patch_length,
+            output_dims=config.hidden_size,
+            hidden_dims=config.intermediate_size,
+        )
+        self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
+        self.layers = nn.ModuleList(
+            [CtsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        # CTSM uses rotary position embeddings only; the additive sinusoidal
+        # `CtsmPositionalEmbedding` is deliberately not instantiated, even when
+        # `config.use_positional_embedding` is set.
+        self.rotary_emb = CtsmRotaryEmbedding(config)
+
+        if config.use_resolution_embeddings:
+            self.multi_resolution = nn.Embedding(config.num_resolutions, config.hidden_size)
+
+        if config.use_special_token:
+            self.special_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _forward_transform(
+        self, 
inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = self._ctsm_masked_mean_std(inputs, patched_pads) + sigma = torch.clamp(sigma, min=self.config.tolerance) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device), + outputs, + ) + return outputs, (mu, sigma) + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + past_values_coarse: torch.Tensor | None = None, + past_values_fine: torch.Tensor | None = None, + past_values_coarse_padding: torch.LongTensor | None = None, + past_values_fine_padding: torch.LongTensor | None = None, + freq: torch.Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + loc_fine: torch.Tensor | None = None, + scale_fine: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + r""" + past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`, *optional*): + Coarse-resolution context (e.g. hourly aggregates). Length must be a multiple of `patch_length` or + will be left-padded to one. Required when `past_key_values` is `None`. + past_values_fine (`torch.FloatTensor` of shape `(batch_size, fine_length)`): + Fine-resolution context (e.g. minute-level). In the normal / full-forward mode this is the entire + fine context; when `past_key_values` is supplied this should contain **only the new fine values** + to append — they must already be pre-normalized by the caller using `loc_fine` / `scale_fine`. + past_values_coarse_padding (`torch.LongTensor`, *optional*): + Padding mask for the coarse stream, `1.0` for padded positions and `0.0` for real values. 
+ past_values_fine_padding (`torch.LongTensor`, *optional*): + Padding mask for the fine stream. + freq (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Frequency indices. Defaults to all zeros. + past_key_values (`Cache`, *optional*): + A [`Cache`] (typically a [`DynamicCache`]) holding K/V for the concatenated + `[coarse, special, fine_prefix]` sequence from a previous call. When supplied the model runs in + **incremental mode**: only the new fine patches are embedded, and their Q/K/V are added on top + of the cached K/V. `loc_fine` / `scale_fine` **must** also be supplied so the new fine values + are normalized on the same scale as the cached ones. + use_cache (`bool`, *optional*): + Whether to build and return a key/value cache in the `CtsmOutput`. Defaults to `False` unless + `past_key_values` is provided (in which case caching is always on). + cache_position (`torch.LongTensor` of shape `(num_new,)`, *optional*): + Absolute positions (in the full `[coarse, special, fine]` sequence) of the new fine patches. + Only used in incremental mode; defaults to `torch.arange(past_length, past_length + num_new)`. + loc_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream mean used for stream normalization. Required in incremental mode. + scale_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream standard deviation used for stream normalization. Required in incremental mode. 
+ """ + if past_key_values is None: + return self._full_forward( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + use_cache=bool(use_cache), + **kwargs, + ) + return self._incremental_forward( + past_values_fine=past_values_fine, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + past_key_values=past_key_values, + cache_position=cache_position, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + + @staticmethod + def _prepare_4d_attention_mask( + attention_mask: torch.Tensor | None, + sequence_length: int, + dtype: torch.dtype, + device: torch.device, + is_causal: bool = True, + ) -> torch.Tensor | None: + """ + Creates 4D attention mask and combines causal and padding masks if needed. + + Args: + attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask + sequence_length: Length of the sequence + dtype: Data type of the mask + device: Device of the mask + is_causal: Whether to apply causal masking + + Returns: + 4D attention mask of shape (batch_size, 1, seq_length, seq_length) + """ + # Get minimum value for the dtype + min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min + + # Handle padding mask + if attention_mask is not None: + # Convert 2D padding mask to 4D attention mask + attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1) + attention_mask = attention_mask * min_value + + # Create causal mask if needed + if is_causal: + causal_mask = torch.triu( + torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value, + diagonal=1, + ) + causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length) + + # Combine with padding mask if it exists + if attention_mask is not None: + attention_mask = torch.minimum(attention_mask, causal_mask) + else: + attention_mask = 
causal_mask + + return attention_mask + + @staticmethod + def _ctsm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded values. + """ + + # Selecting the first patch with more than 3 unpadded values. + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + pad_sum = torch.sum(1 - padding, dim=2) + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.clamp(num_valid_elements, min=1.0) + + # Calculate the masked sum and mean + masked_sum = torch.sum(arr * mask, dim=1) + masked_mean = masked_sum / num_valid_elements # [b] + + # Calculate the masked variance using centered values + masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask + masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements + masked_var = torch.clamp(masked_var, min=0.0) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + @staticmethod + def _ctsm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. 
+ + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + The shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + @staticmethod + def _left_pad_to_patch_boundary( + values: torch.Tensor, paddings: torch.Tensor, patch_length: int + ) -> tuple[torch.Tensor, torch.Tensor]: + rem = values.shape[1] % patch_length + if rem == 0: + return values, paddings + pad_len = patch_length - rem + values_pad = torch.zeros((values.shape[0], pad_len), device=values.device, dtype=values.dtype) + paddings_pad = torch.ones((paddings.shape[0], pad_len), device=paddings.device, dtype=paddings.dtype) + return torch.cat([values_pad, values], dim=1), torch.cat([paddings_pad, paddings], dim=1) + + @staticmethod + def _normalize_with_pad( + context: torch.Tensor, padding: torch.Tensor, tolerance: float = 1e-8 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stream-level normalization that matches the original CTSM reference. + + Normalizes ``context`` using the mean and standard deviation computed over the + non-padded positions (``padding == 0``) across the whole context, rather than + TimesFM's per-first-patch statistics. The normalized tensor has padded positions + zeroed out and is clamped to a safe range. 
+ """ + valid = 1.0 - padding + count = valid.sum(dim=1, keepdim=True).clamp_min(1.0) + mu = (context * valid).sum(dim=1, keepdim=True) / count + + seq_len_f = context.new_tensor(float(context.shape[1])) + filled = torch.where(padding.to(dtype=torch.bool), mu, context) + sigma = filled.std(dim=1, keepdim=True, unbiased=False) * torch.sqrt(seq_len_f / count) + sigma = sigma.clamp_min(1e-2) + + normalized = (context - mu) / (sigma + tolerance) + normalized = normalized * valid + normalized = normalized.clamp(-1000.0, 1000.0) + return normalized, mu.squeeze(-1), sigma.squeeze(-1) + + def _patchify( + self, past_values: torch.Tensor, past_values_padding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Patchify an already stream-normalized stream and project through the input tokenizer.""" + bsize = past_values.shape[0] + patched_inputs = past_values.view(bsize, -1, self.config.patch_length) + patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) + + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + embeddings = self.input_ff_layer(concat_inputs) + patch_padding = torch.min(patched_pads, dim=-1)[0] + return embeddings, patch_padding + + def _build_attention_mask( + self, + patch_padding: torch.Tensor, + num_coarse_patches: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """Reuse TimesFM's padding+causal 4D mask, then open the coarse-coarse block to bidirectional.""" + attention_mask = self._prepare_4d_attention_mask( + attention_mask=patch_padding, + sequence_length=patch_padding.shape[1], + dtype=dtype, + device=patch_padding.device, + is_causal=True, + ) + if num_coarse_patches > 0: + attention_mask[..., :num_coarse_patches, :num_coarse_patches] = 0.0 + return attention_mask + + def _build_incremental_attention_mask( + self, bsize: int, num_new: int, past_length: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Mask for the incremental 
(cached) path: new fine Qs attend to all cached K/V plus causal within the new block.""" + min_value = torch.finfo(dtype).min + mask = torch.zeros((bsize, 1, num_new, past_length + num_new), dtype=dtype, device=device) + if num_new > 1: + causal_new = torch.triu(torch.full((num_new, num_new), min_value, dtype=dtype, device=device), diagonal=1) + mask[:, :, :, past_length:] = causal_new + return mask + + def _full_forward( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.LongTensor | None, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + use_cache: bool, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if past_values_coarse_padding is None: + past_values_coarse_padding = torch.zeros_like(past_values_coarse) + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_coarse_padding = past_values_coarse_padding.to(past_values_coarse.dtype) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + patch_length = self.config.patch_length + past_values_coarse, past_values_coarse_padding = self._left_pad_to_patch_boundary( + past_values_coarse, past_values_coarse_padding, patch_length + ) + past_values_fine, past_values_fine_padding = self._left_pad_to_patch_boundary( + past_values_fine, past_values_fine_padding, patch_length + ) + + coarse_normalized, loc_coarse, scale_coarse = self._normalize_with_pad( + past_values_coarse, past_values_coarse_padding, tolerance=self.config.tolerance + ) + fine_normalized, loc_fine, scale_fine = self._normalize_with_pad( + past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance + ) + + coarse_embeddings, coarse_patch_padding = self._patchify(coarse_normalized, past_values_coarse_padding) + fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + + bsize, num_coarse_patches, hidden_size 
= coarse_embeddings.shape + num_fine_patches = fine_embeddings.shape[1] + device = coarse_embeddings.device + dtype = coarse_embeddings.dtype + + if self.config.use_special_token: + special = self.special_token.to(device=device, dtype=dtype).expand(bsize, 1, hidden_size) + special_padding = torch.zeros(bsize, 1, device=device, dtype=coarse_patch_padding.dtype) + model_input = torch.cat([coarse_embeddings, special, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, special_padding, fine_patch_padding], dim=1) + num_special = 1 + else: + model_input = torch.cat([coarse_embeddings, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, fine_patch_padding], dim=1) + num_special = 0 + + if self.config.use_resolution_embeddings: + mr_coarse = torch.zeros(num_coarse_patches, dtype=torch.long, device=device) + mr_special = torch.full((num_special,), 1, dtype=torch.long, device=device) + mr_fine = torch.full((num_fine_patches,), 2, dtype=torch.long, device=device) + mr_idx = torch.cat([mr_coarse, mr_special, mr_fine], dim=0).unsqueeze(0).expand(bsize, -1) + model_input = model_input + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + model_input = model_input + self.freq_emb(freq) + + attention_mask = self._build_attention_mask(patch_padding, num_coarse_patches, model_input.dtype) + position_ids = ( + torch.arange(model_input.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand(bsize, -1) + ) + position_embeddings = self.rotary_emb(model_input, position_ids) + + past_key_values = DynamicCache(config=self.config) if use_cache else None + + hidden_states = model_input + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=patch_padding, + position_embeddings=position_embeddings, + 
past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + loc_coarse=loc_coarse, + scale_coarse=scale_coarse, + num_coarse_patches=num_coarse_patches + num_special, + num_fine_patches=num_fine_patches, + past_key_values=past_key_values, + ) + + def _incremental_forward( + self, + past_values_fine: torch.Tensor, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + past_key_values: Cache, + cache_position: torch.LongTensor | None, + loc_fine: torch.Tensor | None, + scale_fine: torch.Tensor | None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if loc_fine is None or scale_fine is None: + raise ValueError( + "`loc_fine` and `scale_fine` must be supplied together with `past_key_values` so that the new fine " + "values are normalized on the same scale as the cached ones." + ) + if past_values_fine.shape[1] % self.config.patch_length != 0: + raise ValueError( + f"In incremental mode `past_values_fine` length must be a multiple of `patch_length=" + f"{self.config.patch_length}`; got {past_values_fine.shape[1]}." 
+ ) + + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + tol = self.config.tolerance + fine_normalized = (past_values_fine - loc_fine.unsqueeze(-1)) / (scale_fine.unsqueeze(-1) + tol) + fine_normalized = fine_normalized * (1.0 - past_values_fine_padding) + fine_normalized = fine_normalized.clamp(-1000.0, 1000.0) + + new_embeddings, new_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + bsize, num_new, _ = new_embeddings.shape + device = new_embeddings.device + dtype = new_embeddings.dtype + + if self.config.use_resolution_embeddings: + mr_idx = torch.full((bsize, num_new), 2, dtype=torch.long, device=device) + new_embeddings = new_embeddings + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + new_embeddings = new_embeddings + self.freq_emb(freq) + + past_length = past_key_values.get_seq_length() + if cache_position is None: + cache_position = torch.arange(past_length, past_length + num_new, dtype=torch.long, device=device) + position_ids = cache_position.unsqueeze(0).expand(bsize, -1) + position_embeddings = self.rotary_emb(new_embeddings, position_ids) + + attention_mask = self._build_incremental_attention_mask(bsize, num_new, past_length, dtype, device) + + hidden_states = new_embeddings + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=new_patch_padding, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + num_fine_patches=num_new, + past_key_values=past_key_values, + ) + + +class CtsmModelForPrediction(CtsmPreTrainedModel): + """CTSM model with a 
multi-resolution prediction head and autoregressive multi-resolution decoding.
+
+    For horizons that require autoregressive decoding (``horizon_len > config.horizon_length``) the
+    prediction class reuses a key/value cache across AR steps: the first step runs the full forward
+    and populates a [`DynamicCache`], subsequent steps feed only the newly-appended fine patches
+    through the stack and attend to the cached K/V for every earlier position. Two caveats arise
+    from making a KV cache fit CTSM's architecture:
+
+    * Stream-level normalization statistics (``loc_fine``, ``scale_fine``) are frozen to the values
+      computed on the first step. This is a small approximation: in the cache-free reference
+      implementation, statistics are recomputed after each prediction is appended; in practice the
+      drift is small when forecasts stay in-distribution.
+    * If an AR step would grow the coarse block (a new coarse patch is formed once every
+      ``patch_length * agg_factor / output_patch_len`` steps, i.e. ~every 15 steps at the defaults),
+      the cache is discarded and a full forward is run, rebuilding the cache.
+    """
+
+    def __init__(self, config: CtsmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.context_len = config.context_length
+        self.horizon_len = config.horizon_length
+
+        self.model = CtsmModel(config)
+        num_outputs = 1 + len(config.quantiles)
+        self.horizon_ff_layer = CtsmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * num_outputs,
+            hidden_dims=config.intermediate_size,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _preprocess(
+        self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
+    ) -> tuple[torch.Tensor, ...]:
+        """Pad/truncate input time series to `context_len` and build a padding mask.
+
+        Args:
+            inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task.
+ freq: Optional list of frequencies (returned as a tensor when provided). + context_len: Optional context length override (defaults to `self.context_len`). + + Returns: + Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. + """ + if context_len is None: + context_len = self.context_len + + input_ts, input_padding = [], [] + + for ts in inputs: + input_len = ts.shape[0] + padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) + if input_len < context_len: + num_front_pad = context_len - input_len + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) + elif input_len > context_len: + ts = ts[-context_len:] + padding = padding[-(context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + if freq is not None: + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + return result + + def _postprocess_output( + self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1) + + mu, sigma = stats + return output_ts * sigma[:, None, None, None] + mu[:, None, None, None] + + def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + losses = [] + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[..., i] + loss = torch.max((q - 1) * errors, q * errors) + losses.append(loss.mean()) + return torch.stack(losses).mean() + + @can_return_tuple + @auto_docstring + def forward( + self, + 
past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + future_values: torch.Tensor | None = None, + horizon_len: int | None = None, + freq: Sequence[int] | torch.Tensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutputForPrediction: + r""" + past_values (`Sequence[torch.Tensor]`): + Either a list of 1-D fine-resolution tensors (the coarse stream is derived by mean-aggregating over + `agg_factor` consecutive points) or a list of `(coarse, fine)` pairs if both streams are provided. + future_values (`torch.Tensor`, *optional*): + Optional fine-resolution ground truth used to compute the loss. + horizon_len (`int`, *optional*): + Number of fine-resolution steps to forecast. Defaults to `config.horizon_length`. Values larger than + `config.horizon_length` trigger autoregressive decoding. + freq (`Sequence[int]` or `torch.Tensor`, *optional*): + Frequency indices. Defaults to zeros. + use_cache (`bool`, *optional*): + Whether to use a key/value cache across autoregressive decode steps. Defaults to `True` when + `horizon_len > config.horizon_length` (i.e. when AR decoding is needed) and `False` otherwise. + Set to `False` to force a full recompute at every AR step (matches the original reference + behaviour; slower but avoids the stream-stats-freezing approximation). 
+ """ + device = self.horizon_ff_layer.input_layer.weight.device + horizon_len = horizon_len or self.config.horizon_length + if horizon_len <= 0: + raise ValueError("horizon_len must be positive") + + output_patch_len = self.config.horizon_length + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + + coarse, coarse_pad, fine, fine_pad = self._prepare_context(past_values, device=device) + bsize = coarse.shape[0] + + if freq is None: + freq_tensor = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq_tensor = torch.as_tensor( + list(freq) if not isinstance(freq, torch.Tensor) else freq, dtype=torch.long, device=device + ).view(bsize, 1) + + mean_chunks: list[torch.Tensor] = [] + quant_chunks: list[torch.Tensor] = [] + remaining = horizon_len + last_outputs: CtsmOutput | None = None + max_fine = self.config.context_length + max_coarse = self.config.context_length + agg = self.config.agg_factor + new_fine_patches = self.config.horizon_length // self.config.patch_length + + past_key_values: Cache | None = None + frozen_loc_fine: torch.Tensor | None = None + frozen_scale_fine: torch.Tensor | None = None + coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) + + if use_cache is None: + use_cache = num_decode_patches > 1 + pending_new_fine: torch.Tensor | None = None + + for step_idx in range(num_decode_patches): + if past_key_values is None: + # First step (or after cache invalidation): full forward. The coarse block in the cache + # stays frozen at the initial state — only the fine block grows via subsequent incremental + # steps — which matches how KV caches work for append-only sequences. 
+ mean_patch, quant_patch, last_outputs = self._decode_step_full( + past_values_coarse=coarse, + past_values_fine=fine, + past_values_coarse_padding=coarse_pad, + past_values_fine_padding=fine_pad, + freq=freq_tensor, + use_cache=use_cache, + **kwargs, + ) + past_key_values = last_outputs.past_key_values + frozen_loc_fine = last_outputs.loc + frozen_scale_fine = last_outputs.scale + else: + # Incremental: only the fine values newly appended last step go through the stack. + mean_patch, quant_patch, last_outputs = self._decode_step_incremental( + new_fine_values=pending_new_fine, + freq=freq_tensor, + past_key_values=past_key_values, + loc_fine=frozen_loc_fine, + scale_fine=frozen_scale_fine, + **kwargs, + ) + + take = min(remaining, output_patch_len) + mean_chunks.append(mean_patch[:, :take]) + quant_chunks.append(quant_patch[:, :take, :]) + remaining -= take + if remaining <= 0: + break + + new_fine = mean_patch[:, :output_patch_len] + pending_new_fine = new_fine + + # Track the raw contexts so the next full-forward (initial step or after cache + # invalidation) sees the right state. Mirrors the reference AR loop. 
+ fine = torch.cat([fine, new_fine], dim=1) + fine_pad = torch.cat( + [fine_pad, torch.zeros((bsize, output_patch_len), device=device, dtype=fine_pad.dtype)], dim=1 + ) + if fine.shape[1] > max_fine: + fine = fine[:, -max_fine:] + fine_pad = fine_pad[:, -max_fine:] + + coarse_buffer = torch.cat([coarse_buffer, new_fine], dim=1) + full_blocks = coarse_buffer.shape[1] // agg + if full_blocks > 0: + blocks = coarse_buffer[:, : full_blocks * agg].view(bsize, full_blocks, agg).mean(dim=2) + coarse_buffer = coarse_buffer[:, full_blocks * agg :] + coarse = torch.cat([coarse, blocks], dim=1) + coarse_pad = torch.cat( + [coarse_pad, torch.zeros((bsize, full_blocks), device=device, dtype=coarse_pad.dtype)], dim=1 + ) + if coarse.shape[1] > max_coarse: + coarse = coarse[:, -max_coarse:] + coarse_pad = coarse_pad[:, -max_coarse:] + + if past_key_values is not None: + projected_len = past_key_values.get_seq_length() + new_fine_patches + if projected_len >= self.config.max_position_embeddings: + past_key_values = None + pending_new_fine = None + + mean_predictions = torch.cat(mean_chunks, dim=1)[:, :horizon_len] + full_predictions = torch.cat( + [torch.cat(mean_chunks, dim=1)[:, :horizon_len, None], torch.cat(quant_chunks, dim=1)[:, :horizon_len, :]], + dim=-1, + ) + + loss = None + if future_values is not None: + target_len = min(future_values.shape[1], mean_predictions.shape[1]) + mse_loss = F.mse_loss(mean_predictions[:, :target_len], future_values[:, :target_len]) + quantile_loss = self._quantile_loss(full_predictions[:, :target_len, 1:], future_values[:, :target_len]) + loss = mse_loss + quantile_loss + + return CtsmOutputForPrediction( + last_hidden_state=last_outputs.last_hidden_state if last_outputs is not None else None, + hidden_states=last_outputs.hidden_states if last_outputs is not None else None, + attentions=last_outputs.attentions if last_outputs is not None else None, + mean_predictions=mean_predictions, + full_predictions=full_predictions, + loss=loss, + ) + + 
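The coarse-buffer bookkeeping in the AR loop above (freshly predicted fine values accumulate in `coarse_buffer`, and each group of `agg_factor` of them is mean-aggregated into one new coarse point) can be sketched in isolation. A minimal toy sketch using plain Python lists; the helper name `ar_coarse_bookkeeping` is invented for illustration and is not part of the model API:

```python
def ar_coarse_bookkeeping(fine_chunks, agg):
    """Accumulate fine-resolution chunks; emit one coarse point per `agg` buffered values."""
    buffer, coarse = [], []
    for chunk in fine_chunks:
        buffer.extend(chunk)
        # Each full block of `agg` fine values becomes one mean-aggregated coarse value,
        # mirroring the `coarse_buffer` / `full_blocks` logic in the AR loop.
        while len(buffer) >= agg:
            block, buffer = buffer[:agg], buffer[agg:]
            coarse.append(sum(block) / agg)
    return coarse, buffer

# With agg=4 and three 3-value prediction chunks, two coarse points are emitted
# and one fine value stays buffered for a later step.
coarse, leftover = ar_coarse_bookkeeping([[1, 1, 1], [1, 3, 3], [3, 3, 5]], agg=4)
```

At the defaults (`patch_length=32`, `agg_factor=60`, one 128-step output patch per AR step), a full coarse patch of 32 points therefore only forms once every 32 * 60 / 128 = 15 steps, consistent with the class docstring's note about how rarely the coarse block grows.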
@staticmethod + def _ctsm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + """Compute a left-padded moving average via 1-D convolution; returns `[smoothed, residual]`.""" + # Pad with zeros to handle initial window positions + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + # Create a convolution kernel + kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + # Apply convolution to calculate the moving average + smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() + return [smoothed_arr, arr - smoothed_arr] + + @staticmethod + def _build_multi_resolution( + series: torch.Tensor, agg_factor: int, coarse_len: int, fine_len: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build (coarse, fine) contexts from a 1-D fine-resolution series. + + Coarse is the mean of the last `coarse_len * agg_factor` fine samples, aligned to block boundaries. + Fine is the last `fine_len` samples. + """ + series = series.to(torch.float32).reshape(-1) + needed = coarse_len * agg_factor + raw = series[-needed:] + remainder = raw.shape[0] % agg_factor + if remainder: + raw = raw[remainder:] + if raw.numel() == 0: + coarse = series.new_empty((0,), dtype=torch.float32) + else: + coarse = raw.reshape(-1, agg_factor).mean(dim=1) + if coarse.shape[0] > coarse_len: + coarse = coarse[-coarse_len:] + fine = series[-fine_len:].to(torch.float32) + return coarse, fine + + def _prepare_context( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + coarse_len = self.config.context_length + fine_len = self.config.context_length + agg = self.config.agg_factor + + coarse_batch = torch.zeros((len(past_values), coarse_len), dtype=torch.float32, device=device) + coarse_pad = torch.zeros_like(coarse_batch) + fine_batch = torch.zeros((len(past_values), fine_len), dtype=torch.float32,
device=device) + fine_pad = torch.zeros_like(fine_batch) + + for i, item in enumerate(past_values): + if isinstance(item, (tuple, list)) and len(item) == 2: + coarse, fine = item + coarse = torch.as_tensor(coarse, dtype=torch.float32, device=device).reshape(-1) + fine = torch.as_tensor(fine, dtype=torch.float32, device=device).reshape(-1) + else: + series = torch.as_tensor(item, dtype=torch.float32, device=device).reshape(-1) + coarse, fine = self._build_multi_resolution(series, agg, coarse_len, fine_len) + + c_n = coarse.shape[0] + if c_n >= coarse_len: + coarse_batch[i] = coarse[-coarse_len:] + elif c_n > 0: + coarse_batch[i, coarse_len - c_n :] = coarse + coarse_pad[i, : coarse_len - c_n] = 1.0 + else: + coarse_pad[i] = 1.0 + + f_n = fine.shape[0] + if f_n >= fine_len: + fine_batch[i] = fine[-fine_len:] + elif f_n > 0: + fine_batch[i, fine_len - f_n :] = fine + fine_pad[i, : fine_len - f_n] = 1.0 + else: + fine_pad[i] = 1.0 + + return coarse_batch, coarse_pad, fine_batch, fine_pad + + def _project_last_fine(self, outputs: CtsmOutput, last_position: int) -> tuple[torch.Tensor, torch.Tensor]: + """Project the hidden state at `last_position` through the horizon head and denormalize.""" + last_hidden = outputs.last_hidden_state[:, last_position : last_position + 1, :] + head = self.horizon_ff_layer(last_hidden) + bsize = head.shape[0] + num_outputs = 1 + len(self.config.quantiles) + head = head.view(bsize, self.config.horizon_length, num_outputs) + + loc = outputs.loc[:, None, None] + scale = outputs.scale[:, None, None] + mean_patch = head[..., 0] * scale[..., 0] + loc[..., 0] + quant_patch = head[..., 1:] * scale + loc + mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) + quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + return mean_patch, quant_patch + + def _decode_step_full( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.Tensor, + 
past_values_fine_padding: torch.Tensor, + freq: torch.Tensor, + use_cache: bool, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """Full forward through the model. If `use_cache`, the returned outputs carry a fresh cache.""" + outputs: CtsmOutput = self.model( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + use_cache=use_cache, + **kwargs, + ) + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) + return mean_patch, quant_patch, outputs + + def _decode_step_incremental( + self, + new_fine_values: torch.Tensor, + freq: torch.Tensor, + past_key_values: Cache, + loc_fine: torch.Tensor, + scale_fine: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """Append `new_fine_values` to the cached state and run only the new positions through the stack.""" + outputs: CtsmOutput = self.model( + past_values_fine=new_fine_values, + freq=freq, + past_key_values=past_key_values, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) + return mean_patch, quant_patch, outputs + + +__all__ = ["CtsmModel", "CtsmModelForPrediction", "CtsmPreTrainedModel"] diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py new file mode 100644 index 000000000000..e56fe16403c5 --- /dev/null +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -0,0 +1,941 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Cisco Time Series Model (CTSM).""" + +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub.dataclasses import strict + +from ... import initialization as init +from ...cache_utils import Cache, DynamicCache +from ...modeling_rope_utils import RopeParameters +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward +from ..timesfm.configuration_timesfm import TimesFmConfig +from ..timesfm.modeling_timesfm import ( + TimesFmAttention, + TimesFmDecoderLayer, + TimesFmModel, + TimesFmModelForPrediction, + TimesFmOutput, + TimesFmOutputForPrediction, + TimesFmPreTrainedModel, + TimesFmResidualBlock, # re-exported as CtsmResidualBlock in the generated file +) +from ..timesfm2_5.modeling_timesfm2_5 import ( + TimesFm2_5RotaryEmbedding, + apply_rotary_pos_emb, +) + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="cisco-ai/cisco-time-series-model-1.0") +@strict +class CtsmConfig(TimesFmConfig): + r""" + patch_length (`int`, *optional*, defaults to 32): + Length of one patch in the input sequence for each resolution stream. 
+ context_length (`int`, *optional*, defaults to 512): + Length of the input context for each resolution stream. + horizon_length (`int`, *optional*, defaults to 128): + Length of the prediction horizon produced per autoregressive step. + freq_size (`int`, *optional*, defaults to 3): + Number of frequency embeddings. + tolerance (`float`, *optional*, defaults to 1e-06): + Numerical tolerance used in normalization. + pad_val (`float`, *optional*, defaults to 1123581321.0): + Sentinel value marking padded positions in the input series. + num_hidden_layers (`int`, *optional*, defaults to 25): + Number of decoder layers. + quantiles (`list[float]`, *optional*, defaults to 15 values between 0.01 and 0.99): + Quantile levels predicted by the model. + use_positional_embedding (`bool`, *optional*, defaults to `False`): + CTSM uses rotary position embeddings and does not add sinusoidal positional embeddings. + use_resolution_embeddings (`bool`, *optional*, defaults to `True`): + Whether to add a learned embedding per resolution bucket (coarse / special / fine). + use_special_token (`bool`, *optional*, defaults to `True`): + Whether to insert a learned special token between the coarse and fine streams. + num_resolutions (`int`, *optional*, defaults to 3): + Number of resolution embeddings (coarse, special token, fine). + agg_factor (`int`, *optional*, defaults to 60): + Aggregation factor between fine and coarse resolutions (e.g. 60 minutes -> 1 hour). + max_position_embeddings (`int`, *optional*, defaults to 1025): + Maximum number of patches in the concatenated sequence (coarse + special + fine). + rope_parameters (`dict`, *optional*): + Rotary position embedding parameters. Defaults to `{"rope_type": "default", "rope_theta": 10000.0}`. 
+ + Example: + + ```python + >>> from transformers import CtsmConfig, CtsmModelForPrediction + + >>> configuration = CtsmConfig() + >>> model = CtsmModelForPrediction(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "ctsm" + + num_hidden_layers: int = 25 + context_length: int = 512 + quantiles: list[float] | tuple[float, ...] = ( + 0.01, + 0.05, + 0.1, + 0.2, + 0.25, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.75, + 0.8, + 0.9, + 0.95, + 0.99, + ) + use_positional_embedding: bool = False + use_resolution_embeddings: bool = True + use_special_token: bool = True + num_resolutions: int = 3 + agg_factor: int = 60 + max_position_embeddings: int = 1025 + rope_parameters: RopeParameters | dict | None = None + + min_timescale = AttributeError() + max_timescale = AttributeError() + + +@dataclass +@auto_docstring +class CtsmOutput(TimesFmOutput): + r""" + loc (`torch.Tensor` of shape `(batch_size,)`): + Stream-level mean used to normalize the fine-resolution context, reused to rescale the final forecast. + scale (`torch.Tensor` of shape `(batch_size,)`): + Stream-level standard deviation of the fine-resolution context. + loc_coarse (`torch.Tensor` of shape `(batch_size,)`): + Stream-level mean used to normalize the coarse-resolution context. + scale_coarse (`torch.Tensor` of shape `(batch_size,)`): + Stream-level standard deviation of the coarse-resolution context. + num_coarse_patches (`int`): + Number of patches (including the optional special token) preceding the fine-resolution block. + num_fine_patches (`int`): + Number of patches in the fine-resolution block of the concatenated sequence. + past_key_values (`Cache`, *optional*): + Key/value cache for the concatenated `[coarse, special, fine]` sequence. Populated when the + caller passes `use_cache=True` (and re-used across autoregressive decode steps). Typically only + the long-horizon AR loop in [`CtsmModelForPrediction`] needs this. 
+ """ + + loc_coarse: torch.Tensor | None = None + scale_coarse: torch.Tensor | None = None + num_coarse_patches: int | None = None + num_fine_patches: int | None = None + past_key_values: Cache | None = None + + +@dataclass +@auto_docstring +class CtsmOutputForPrediction(TimesFmOutputForPrediction): + r""" + mean_predictions (`torch.Tensor` of shape `(batch_size, horizon_length)`): + Point forecasts over the fine-resolution horizon. + full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`): + Concatenation of the mean prediction and the quantile predictions along the last axis. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + Training loss combining MSE of the mean forecast and quantile loss when fine-resolution targets are supplied. + """ + + pass + + +class CtsmResidualBlock(TimesFmResidualBlock): + pass + + +class CtsmRotaryEmbedding(TimesFm2_5RotaryEmbedding): + pass + + +class CtsmAttention(TimesFmAttention): + """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings. + + Supports an optional `past_key_values` cache so that, during long-horizon autoregressive decoding, + each step only needs to compute K/V for the newly-appended fine patches and attends to the + previously-cached K/V for every earlier position. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + query_states = self._scale_query(query_states) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CtsmDecoderLayer(TimesFmDecoderLayer): + """CTSM transformer block: attention with RoPE followed by TimesFM 2.0 MLP with padding masking.""" + + def __init__(self, config: CtsmConfig, layer_idx: int): + super().__init__(config, layer_idx=layer_idx) + self.self_attn = CtsmAttention(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + paddings: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + past_key_values: Cache | None = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, paddings=paddings) + return hidden_states + + +@auto_docstring +class CtsmPreTrainedModel(TimesFmPreTrainedModel): + config: CtsmConfig + base_model_prefix = "model" + _no_split_modules = ["CtsmDecoderLayer"] + _supports_flash_attn = True + _supports_flex_attn = True + _can_record_outputs = { + "hidden_states": CtsmDecoderLayer, + "attentions": CtsmAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, CtsmModel) and getattr(module, "special_token", None) is not None: + init.normal_(module.special_token, mean=0.0, std=self.config.initializer_range) + + +class CtsmModel(TimesFmModel): + r""" + The multi-resolution CTSM encoder. The forward pass consumes two aligned streams (a coarse low-frequency + context and a fine high-frequency context), concatenates them along the sequence dimension with an + optional learned special token, and runs a stack of rotary-attention transformer layers. Attention is + bidirectional within the coarse block and causal elsewhere. 
+ """ + + def __init__(self, config: CtsmConfig): + super().__init__(config) + + if hasattr(self, "position_emb"): + del self.position_emb + + self.rotary_emb = CtsmRotaryEmbedding(config) + + if config.use_resolution_embeddings: + self.multi_resolution = nn.Embedding(config.num_resolutions, config.hidden_size) + + if config.use_special_token: + self.special_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + self.post_init() + + @staticmethod + def _left_pad_to_patch_boundary( + values: torch.Tensor, paddings: torch.Tensor, patch_length: int + ) -> tuple[torch.Tensor, torch.Tensor]: + rem = values.shape[1] % patch_length + if rem == 0: + return values, paddings + pad_len = patch_length - rem + values_pad = torch.zeros((values.shape[0], pad_len), device=values.device, dtype=values.dtype) + paddings_pad = torch.ones((paddings.shape[0], pad_len), device=paddings.device, dtype=paddings.dtype) + return torch.cat([values_pad, values], dim=1), torch.cat([paddings_pad, paddings], dim=1) + + @staticmethod + def _normalize_with_pad( + context: torch.Tensor, padding: torch.Tensor, tolerance: float = 1e-8 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stream-level normalization that matches the original CTSM reference. + + Normalizes ``context`` using the mean and standard deviation computed over the + non-padded positions (``padding == 0``) across the whole context, rather than + TimesFM's per-first-patch statistics. The normalized tensor has padded positions + zeroed out and is clamped to a safe range. 
+ """ + valid = 1.0 - padding + count = valid.sum(dim=1, keepdim=True).clamp_min(1.0) + mu = (context * valid).sum(dim=1, keepdim=True) / count + + seq_len_f = context.new_tensor(float(context.shape[1])) + filled = torch.where(padding.to(dtype=torch.bool), mu, context) + sigma = filled.std(dim=1, keepdim=True, unbiased=False) * torch.sqrt(seq_len_f / count) + sigma = sigma.clamp_min(1e-2) + + normalized = (context - mu) / (sigma + tolerance) + normalized = normalized * valid + normalized = normalized.clamp(-1000.0, 1000.0) + return normalized, mu.squeeze(-1), sigma.squeeze(-1) + + def _patchify( + self, past_values: torch.Tensor, past_values_padding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Patchify an already stream-normalized stream and project through the input tokenizer.""" + bsize = past_values.shape[0] + patched_inputs = past_values.view(bsize, -1, self.config.patch_length) + patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) + + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + embeddings = self.input_ff_layer(concat_inputs) + patch_padding = torch.min(patched_pads, dim=-1)[0] + return embeddings, patch_padding + + def _build_attention_mask( + self, + patch_padding: torch.Tensor, + num_coarse_patches: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """Reuse TimesFM's padding+causal 4D mask, then open the coarse-coarse block to bidirectional.""" + attention_mask = self._prepare_4d_attention_mask( + attention_mask=patch_padding, + sequence_length=patch_padding.shape[1], + dtype=dtype, + device=patch_padding.device, + is_causal=True, + ) + if num_coarse_patches > 0: + attention_mask[..., :num_coarse_patches, :num_coarse_patches] = 0.0 + return attention_mask + + def _build_incremental_attention_mask( + self, bsize: int, num_new: int, past_length: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Mask for the incremental 
(cached) path: new fine Qs attend to all cached K/V plus causal within the new block.""" + min_value = torch.finfo(dtype).min + mask = torch.zeros((bsize, 1, num_new, past_length + num_new), dtype=dtype, device=device) + if num_new > 1: + causal_new = torch.triu(torch.full((num_new, num_new), min_value, dtype=dtype, device=device), diagonal=1) + mask[:, :, :, past_length:] = causal_new + return mask + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + past_values_coarse: torch.Tensor | None = None, + past_values_fine: torch.Tensor | None = None, + past_values_coarse_padding: torch.LongTensor | None = None, + past_values_fine_padding: torch.LongTensor | None = None, + freq: torch.Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + loc_fine: torch.Tensor | None = None, + scale_fine: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + r""" + past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`, *optional*): + Coarse-resolution context (e.g. hourly aggregates). Length must be a multiple of `patch_length` or + will be left-padded to one. Required when `past_key_values` is `None`. + past_values_fine (`torch.FloatTensor` of shape `(batch_size, fine_length)`): + Fine-resolution context (e.g. minute-level). In the normal / full-forward mode this is the entire + fine context; when `past_key_values` is supplied this should contain **only the new fine values** + to append, which the model normalizes internally using the supplied `loc_fine` / `scale_fine`. + past_values_coarse_padding (`torch.LongTensor`, *optional*): + Padding mask for the coarse stream, `1.0` for padded positions and `0.0` for real values. + past_values_fine_padding (`torch.LongTensor`, *optional*): + Padding mask for the fine stream.
+ freq (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Frequency indices. Defaults to all zeros. + past_key_values (`Cache`, *optional*): + A [`Cache`] (typically a [`DynamicCache`]) holding K/V for the concatenated + `[coarse, special, fine_prefix]` sequence from a previous call. When supplied the model runs in + **incremental mode**: only the new fine patches are embedded, and their Q/K/V are added on top + of the cached K/V. `loc_fine` / `scale_fine` **must** also be supplied so the new fine values + are normalized on the same scale as the cached ones. + use_cache (`bool`, *optional*): + Whether to build and return a key/value cache in the `CtsmOutput`. Defaults to `False` unless + `past_key_values` is provided (in which case caching is always on). + cache_position (`torch.LongTensor` of shape `(num_new,)`, *optional*): + Absolute positions (in the full `[coarse, special, fine]` sequence) of the new fine patches. + Only used in incremental mode; defaults to `torch.arange(past_length, past_length + num_new)`. + loc_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream mean used for stream normalization. Required in incremental mode. + scale_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream standard deviation used for stream normalization. Required in incremental mode. 
+ """ + if past_key_values is None: + return self._full_forward( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + use_cache=bool(use_cache), + **kwargs, + ) + return self._incremental_forward( + past_values_fine=past_values_fine, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + past_key_values=past_key_values, + cache_position=cache_position, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + + def _full_forward( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.LongTensor | None, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + use_cache: bool, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if past_values_coarse_padding is None: + past_values_coarse_padding = torch.zeros_like(past_values_coarse) + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_coarse_padding = past_values_coarse_padding.to(past_values_coarse.dtype) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + patch_length = self.config.patch_length + past_values_coarse, past_values_coarse_padding = self._left_pad_to_patch_boundary( + past_values_coarse, past_values_coarse_padding, patch_length + ) + past_values_fine, past_values_fine_padding = self._left_pad_to_patch_boundary( + past_values_fine, past_values_fine_padding, patch_length + ) + + coarse_normalized, loc_coarse, scale_coarse = self._normalize_with_pad( + past_values_coarse, past_values_coarse_padding, tolerance=self.config.tolerance + ) + fine_normalized, loc_fine, scale_fine = self._normalize_with_pad( + past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance + ) + + coarse_embeddings, coarse_patch_padding = 
self._patchify(coarse_normalized, past_values_coarse_padding) + fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + + bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape + num_fine_patches = fine_embeddings.shape[1] + device = coarse_embeddings.device + dtype = coarse_embeddings.dtype + + if self.config.use_special_token: + special = self.special_token.to(device=device, dtype=dtype).expand(bsize, 1, hidden_size) + special_padding = torch.zeros(bsize, 1, device=device, dtype=coarse_patch_padding.dtype) + model_input = torch.cat([coarse_embeddings, special, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, special_padding, fine_patch_padding], dim=1) + num_special = 1 + else: + model_input = torch.cat([coarse_embeddings, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, fine_patch_padding], dim=1) + num_special = 0 + + if self.config.use_resolution_embeddings: + mr_coarse = torch.zeros(num_coarse_patches, dtype=torch.long, device=device) + mr_special = torch.full((num_special,), 1, dtype=torch.long, device=device) + mr_fine = torch.full((num_fine_patches,), 2, dtype=torch.long, device=device) + mr_idx = torch.cat([mr_coarse, mr_special, mr_fine], dim=0).unsqueeze(0).expand(bsize, -1) + model_input = model_input + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + model_input = model_input + self.freq_emb(freq) + + attention_mask = self._build_attention_mask(patch_padding, num_coarse_patches, model_input.dtype) + position_ids = ( + torch.arange(model_input.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand(bsize, -1) + ) + position_embeddings = self.rotary_emb(model_input, position_ids) + + past_key_values = DynamicCache(config=self.config) if use_cache else None + + hidden_states = model_input + for layer in self.layers[: 
self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=patch_padding, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + loc_coarse=loc_coarse, + scale_coarse=scale_coarse, + num_coarse_patches=num_coarse_patches + num_special, + num_fine_patches=num_fine_patches, + past_key_values=past_key_values, + ) + + def _incremental_forward( + self, + past_values_fine: torch.Tensor, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + past_key_values: Cache, + cache_position: torch.LongTensor | None, + loc_fine: torch.Tensor | None, + scale_fine: torch.Tensor | None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if loc_fine is None or scale_fine is None: + raise ValueError( + "`loc_fine` and `scale_fine` must be supplied together with `past_key_values` so that the new fine " + "values are normalized on the same scale as the cached ones." + ) + if past_values_fine.shape[1] % self.config.patch_length != 0: + raise ValueError( + f"In incremental mode `past_values_fine` length must be a multiple of `patch_length=" + f"{self.config.patch_length}`; got {past_values_fine.shape[1]}." 
+ ) + + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + tol = self.config.tolerance + fine_normalized = (past_values_fine - loc_fine.unsqueeze(-1)) / (scale_fine.unsqueeze(-1) + tol) + fine_normalized = fine_normalized * (1.0 - past_values_fine_padding) + fine_normalized = fine_normalized.clamp(-1000.0, 1000.0) + + new_embeddings, new_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + bsize, num_new, _ = new_embeddings.shape + device = new_embeddings.device + dtype = new_embeddings.dtype + + if self.config.use_resolution_embeddings: + mr_idx = torch.full((bsize, num_new), 2, dtype=torch.long, device=device) + new_embeddings = new_embeddings + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + new_embeddings = new_embeddings + self.freq_emb(freq) + + past_length = past_key_values.get_seq_length() + if cache_position is None: + cache_position = torch.arange(past_length, past_length + num_new, dtype=torch.long, device=device) + position_ids = cache_position.unsqueeze(0).expand(bsize, -1) + position_embeddings = self.rotary_emb(new_embeddings, position_ids) + + attention_mask = self._build_incremental_attention_mask(bsize, num_new, past_length, dtype, device) + + hidden_states = new_embeddings + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=new_patch_padding, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + num_fine_patches=num_new, + past_key_values=past_key_values, + ) + + +class CtsmModelForPrediction(TimesFmModelForPrediction): + """CTSM model 
with a multi-resolution prediction head and autoregressive multi-resolution decoding.
+
+    For horizons that require autoregressive decoding (``horizon_len > config.horizon_length``) the
+    prediction class reuses a key/value cache across AR steps: the first step runs the full forward
+    and populates a [`DynamicCache`]; subsequent steps feed only the newly appended fine patches
+    through the stack and attend to the cached K/V for every earlier position. Two caveats apply,
+    reflecting how a KV cache is adapted to CTSM's architecture:
+
+    * Stream-level normalization statistics (``loc_fine``, ``scale_fine``) are frozen to the values
+      computed on the first step. This is a small approximation: in the original reference
+      implementation, statistics are recomputed after each prediction is appended; in practice the
+      drift is small when forecasts stay in-distribution.
+    * If an AR step would grow the coarse block (a new coarse patch is formed once every
+      ``patch_length * agg_factor / output_patch_len`` steps, i.e. ~every 15 steps at the defaults),
+      the cache is discarded and a full forward is run, rebuilding the cache.
+    """
+
+    def __init__(self, config: CtsmConfig):
+        super().__init__(config)
+        del self.decoder
+        del self.horizon_ff_layer
+
+        self.model = CtsmModel(config)
+        num_outputs = 1 + len(config.quantiles)
+        self.horizon_ff_layer = CtsmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * num_outputs,
+            hidden_dims=config.intermediate_size,
+        )
+        self.post_init()
+
+    @staticmethod
+    def _build_multi_resolution(
+        series: torch.Tensor, agg_factor: int, coarse_len: int, fine_len: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Build (coarse, fine) contexts from a 1-D fine-resolution series.
+
+        Coarse is the mean of the last `coarse_len * agg_factor` fine samples, aligned to block boundaries.
+        Fine is the last `fine_len` samples.
+ """ + series = series.to(torch.float32).reshape(-1) + needed = coarse_len * agg_factor + raw = series[-needed:] + remainder = raw.shape[0] % agg_factor + if remainder: + raw = raw[remainder:] + if raw.numel() == 0: + coarse = series.new_empty((0,), dtype=torch.float32) + else: + coarse = raw.reshape(-1, agg_factor).mean(dim=1) + if coarse.shape[0] > coarse_len: + coarse = coarse[-coarse_len:] + fine = series[-fine_len:].to(torch.float32) + return coarse, fine + + def _prepare_context( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + coarse_len = self.config.context_length + fine_len = self.config.context_length + agg = self.config.agg_factor + + coarse_batch = torch.zeros((len(past_values), coarse_len), dtype=torch.float32, device=device) + coarse_pad = torch.zeros_like(coarse_batch) + fine_batch = torch.zeros((len(past_values), fine_len), dtype=torch.float32, device=device) + fine_pad = torch.zeros_like(fine_batch) + + for i, item in enumerate(past_values): + if isinstance(item, (tuple, list)) and len(item) == 2: + coarse, fine = item + coarse = torch.as_tensor(coarse, dtype=torch.float32, device=device).reshape(-1) + fine = torch.as_tensor(fine, dtype=torch.float32, device=device).reshape(-1) + else: + series = torch.as_tensor(item, dtype=torch.float32, device=device).reshape(-1) + coarse, fine = self._build_multi_resolution(series, agg, coarse_len, fine_len) + + c_n = coarse.shape[0] + if c_n >= coarse_len: + coarse_batch[i] = coarse[-coarse_len:] + elif c_n > 0: + coarse_batch[i, coarse_len - c_n :] = coarse + coarse_pad[i, : coarse_len - c_n] = 1.0 + else: + coarse_pad[i] = 1.0 + + f_n = fine.shape[0] + if f_n >= fine_len: + fine_batch[i] = fine[-fine_len:] + elif f_n > 0: + fine_batch[i, fine_len - f_n :] = fine + fine_pad[i, : fine_len - f_n] = 1.0 + else: + fine_pad[i] = 1.0 + + return coarse_batch, coarse_pad, 
fine_batch, fine_pad + + def _project_last_fine(self, outputs: CtsmOutput, last_position: int) -> tuple[torch.Tensor, torch.Tensor]: + """Project the hidden state at `last_position` through the horizon head and denormalize.""" + last_hidden = outputs.last_hidden_state[:, last_position : last_position + 1, :] + head = self.horizon_ff_layer(last_hidden) + bsize = head.shape[0] + num_outputs = 1 + len(self.config.quantiles) + head = head.view(bsize, self.config.horizon_length, num_outputs) + + loc = outputs.loc[:, None, None] + scale = outputs.scale[:, None, None] + mean_patch = head[..., 0] * scale[..., 0] + loc[..., 0] + quant_patch = head[..., 1:] * scale + loc + mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) + quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + return mean_patch, quant_patch + + def _decode_step_full( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.Tensor, + past_values_fine_padding: torch.Tensor, + freq: torch.Tensor, + use_cache: bool, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """Full forward through the model. 
If `use_cache`, the returned outputs carry a fresh cache.""" + outputs: CtsmOutput = self.model( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + use_cache=use_cache, + **kwargs, + ) + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) + return mean_patch, quant_patch, outputs + + def _decode_step_incremental( + self, + new_fine_values: torch.Tensor, + freq: torch.Tensor, + past_key_values: Cache, + loc_fine: torch.Tensor, + scale_fine: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """Append `new_fine_values` to the cached state and run only the new positions through the stack.""" + outputs: CtsmOutput = self.model( + past_values_fine=new_fine_values, + freq=freq, + past_key_values=past_key_values, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) + return mean_patch, quant_patch, outputs + + @can_return_tuple + @auto_docstring + def forward( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + future_values: torch.Tensor | None = None, + horizon_len: int | None = None, + freq: Sequence[int] | torch.Tensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutputForPrediction: + r""" + past_values (`Sequence[torch.Tensor]`): + Either a list of 1-D fine-resolution tensors (the coarse stream is derived by mean-aggregating over + `agg_factor` consecutive points) or a list of `(coarse, fine)` pairs if both streams are provided. + future_values (`torch.Tensor`, *optional*): + Optional fine-resolution ground truth used to compute the loss. 
+ horizon_len (`int`, *optional*): + Number of fine-resolution steps to forecast. Defaults to `config.horizon_length`. Values larger than + `config.horizon_length` trigger autoregressive decoding. + freq (`Sequence[int]` or `torch.Tensor`, *optional*): + Frequency indices. Defaults to zeros. + use_cache (`bool`, *optional*): + Whether to use a key/value cache across autoregressive decode steps. Defaults to `True` when + `horizon_len > config.horizon_length` (i.e. when AR decoding is needed) and `False` otherwise. + Set to `False` to force a full recompute at every AR step (matches the original reference + behaviour; slower but avoids the stream-stats-freezing approximation). + """ + device = self.horizon_ff_layer.input_layer.weight.device + horizon_len = horizon_len or self.config.horizon_length + if horizon_len <= 0: + raise ValueError("horizon_len must be positive") + + output_patch_len = self.config.horizon_length + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + + coarse, coarse_pad, fine, fine_pad = self._prepare_context(past_values, device=device) + bsize = coarse.shape[0] + + if freq is None: + freq_tensor = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq_tensor = torch.as_tensor( + list(freq) if not isinstance(freq, torch.Tensor) else freq, dtype=torch.long, device=device + ).view(bsize, 1) + + mean_chunks: list[torch.Tensor] = [] + quant_chunks: list[torch.Tensor] = [] + remaining = horizon_len + last_outputs: CtsmOutput | None = None + max_fine = self.config.context_length + max_coarse = self.config.context_length + agg = self.config.agg_factor + new_fine_patches = self.config.horizon_length // self.config.patch_length + + past_key_values: Cache | None = None + frozen_loc_fine: torch.Tensor | None = None + frozen_scale_fine: torch.Tensor | None = None + coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) + + if use_cache is None: + use_cache = num_decode_patches > 1 + 
pending_new_fine: torch.Tensor | None = None + + for step_idx in range(num_decode_patches): + if past_key_values is None: + # First step (or after cache invalidation): full forward. The coarse block in the cache + # stays frozen at the initial state — only the fine block grows via subsequent incremental + # steps — which matches how KV caches work for append-only sequences. + mean_patch, quant_patch, last_outputs = self._decode_step_full( + past_values_coarse=coarse, + past_values_fine=fine, + past_values_coarse_padding=coarse_pad, + past_values_fine_padding=fine_pad, + freq=freq_tensor, + use_cache=use_cache, + **kwargs, + ) + past_key_values = last_outputs.past_key_values + frozen_loc_fine = last_outputs.loc + frozen_scale_fine = last_outputs.scale + else: + # Incremental: only the fine values newly appended last step go through the stack. + mean_patch, quant_patch, last_outputs = self._decode_step_incremental( + new_fine_values=pending_new_fine, + freq=freq_tensor, + past_key_values=past_key_values, + loc_fine=frozen_loc_fine, + scale_fine=frozen_scale_fine, + **kwargs, + ) + + take = min(remaining, output_patch_len) + mean_chunks.append(mean_patch[:, :take]) + quant_chunks.append(quant_patch[:, :take, :]) + remaining -= take + if remaining <= 0: + break + + new_fine = mean_patch[:, :output_patch_len] + pending_new_fine = new_fine + + # Track the raw contexts so the next full-forward (initial step or after cache + # invalidation) sees the right state. Mirrors the reference AR loop. 
+ fine = torch.cat([fine, new_fine], dim=1) + fine_pad = torch.cat( + [fine_pad, torch.zeros((bsize, output_patch_len), device=device, dtype=fine_pad.dtype)], dim=1 + ) + if fine.shape[1] > max_fine: + fine = fine[:, -max_fine:] + fine_pad = fine_pad[:, -max_fine:] + + coarse_buffer = torch.cat([coarse_buffer, new_fine], dim=1) + full_blocks = coarse_buffer.shape[1] // agg + if full_blocks > 0: + blocks = coarse_buffer[:, : full_blocks * agg].view(bsize, full_blocks, agg).mean(dim=2) + coarse_buffer = coarse_buffer[:, full_blocks * agg :] + coarse = torch.cat([coarse, blocks], dim=1) + coarse_pad = torch.cat( + [coarse_pad, torch.zeros((bsize, full_blocks), device=device, dtype=coarse_pad.dtype)], dim=1 + ) + if coarse.shape[1] > max_coarse: + coarse = coarse[:, -max_coarse:] + coarse_pad = coarse_pad[:, -max_coarse:] + + if past_key_values is not None: + projected_len = past_key_values.get_seq_length() + new_fine_patches + if projected_len >= self.config.max_position_embeddings: + past_key_values = None + pending_new_fine = None + + mean_predictions = torch.cat(mean_chunks, dim=1)[:, :horizon_len] + full_predictions = torch.cat( + [torch.cat(mean_chunks, dim=1)[:, :horizon_len, None], torch.cat(quant_chunks, dim=1)[:, :horizon_len, :]], + dim=-1, + ) + + loss = None + if future_values is not None: + target_len = min(future_values.shape[1], mean_predictions.shape[1]) + mse_loss = F.mse_loss(mean_predictions[:, :target_len], future_values[:, :target_len]) + quantile_loss = self._quantile_loss(full_predictions[:, :target_len, 1:], future_values[:, :target_len]) + loss = mse_loss + quantile_loss + + return CtsmOutputForPrediction( + last_hidden_state=last_outputs.last_hidden_state if last_outputs is not None else None, + hidden_states=last_outputs.hidden_states if last_outputs is not None else None, + attentions=last_outputs.attentions if last_outputs is not None else None, + mean_predictions=mean_predictions, + full_predictions=full_predictions, + loss=loss, + ) + + 
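The `(coarse, fine)` context construction used throughout the prediction path can be sketched in isolation. Below is an illustrative NumPy re-implementation of the same block-aligned aggregation (the standalone helper name `build_multi_resolution` and the use of NumPy instead of `torch` are ours, for illustration only, not part of this PR):

```python
import numpy as np


def build_multi_resolution(series, agg_factor, coarse_len, fine_len):
    """Split a 1-D fine-resolution series into aligned (coarse, fine) contexts.

    Coarse points are means over `agg_factor` consecutive fine samples, drawn
    from the last `coarse_len * agg_factor` points; fine is simply the last
    `fine_len` raw points.
    """
    series = np.asarray(series, dtype=np.float32).reshape(-1)
    raw = series[-coarse_len * agg_factor :]
    # Drop a partial leading block so every coarse point averages exactly
    # `agg_factor` samples (block-boundary alignment, as in the model code).
    remainder = raw.shape[0] % agg_factor
    if remainder:
        raw = raw[remainder:]
    coarse = raw.reshape(-1, agg_factor).mean(axis=1) if raw.size else raw
    return coarse[-coarse_len:], series[-fine_len:]


series = np.arange(480, dtype=np.float32)  # 480 fine-resolution samples
coarse, fine = build_multi_resolution(series, agg_factor=60, coarse_len=8, fine_len=64)
# coarse[0] averages series[0:60] -> 29.5; fine starts at series[416]
```

With the paper's defaults (512 coarse points, 512 fine points, ratio 60), this is exactly the "30x shorter context" trade-off the overview describes: one coarse point summarizes 60 fine samples.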
+__all__ = [ + "CtsmConfig", + "CtsmModel", + "CtsmModelForPrediction", + "CtsmPreTrainedModel", +] diff --git a/tests/models/ctsm/__init__.py b/tests/models/ctsm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/ctsm/test_modeling_ctsm.py b/tests/models/ctsm/test_modeling_ctsm.py new file mode 100644 index 000000000000..fa37870a9a3a --- /dev/null +++ b/tests/models/ctsm/test_modeling_ctsm.py @@ -0,0 +1,294 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from transformers import CtsmConfig, is_torch_available +from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, floats_tensor + + +if is_torch_available(): + from transformers import CtsmModel, CtsmModelForPrediction + + +class CtsmModelTester: + def __init__( + self, + parent, + patch_length: int = 8, + context_length: int = 64, + horizon_length: int = 8, + num_hidden_layers: int = 2, + hidden_size: int = 32, + intermediate_size: int = 32, + head_dim: int = 16, + num_attention_heads: int = 2, + num_key_value_heads: int = 2, + rms_norm_eps: float = 1e-6, + quantiles=(0.1, 0.5, 0.9), + agg_factor: int = 4, + max_position_embeddings: int = 64, + batch_size: int = 2, + is_training: bool = True, + ): + self.parent = parent + self.patch_length = patch_length + self.context_length = context_length + self.horizon_length = horizon_length + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.quantiles = list(quantiles) + self.agg_factor = agg_factor + self.max_position_embeddings = max_position_embeddings + self.batch_size = batch_size + self.is_training = is_training + + # Total patches in the concatenated sequence (coarse + special + fine). 
+ self.seq_length = 2 * (context_length // patch_length) + 1 + + def get_config(self): + return CtsmConfig( + patch_length=self.patch_length, + context_length=self.context_length, + horizon_length=self.horizon_length, + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + rms_norm_eps=self.rms_norm_eps, + quantiles=self.quantiles, + agg_factor=self.agg_factor, + max_position_embeddings=self.max_position_embeddings, + ) + + def get_pipeline_config(self): + return self.get_config() + + def prepare_config_and_inputs(self): + bsize = self.batch_size + past_values = [ + torch.tensor( + np.sin(np.linspace(0, 20, self.agg_factor * self.context_length)), + dtype=torch.float32, + device=torch_device, + ) + for _ in range(bsize) + ] + return self.get_config(), past_values + + def prepare_config_and_inputs_for_common(self): + config, past_values = self.prepare_config_and_inputs() + inputs_dict = {"past_values": past_values} + return config, inputs_dict + + +@require_torch +class CtsmModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (CtsmModelForPrediction,) if is_torch_available() else () + test_resize_embeddings = False + is_encoder_decoder = False + test_inputs_embeds = False + test_all_params_have_gradient = False + test_headmasking = False + test_pruning = False + test_missing_keys = False + test_model_parallel = False + + def setUp(self): + self.model_tester = CtsmModelTester(self) + self.config_tester = ConfigTester(self, config_class=CtsmConfig, has_text_modality=False) + + def test_create_and_run_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = CtsmModelForPrediction(config) + model.to(torch_device) + model.eval() + results = model(**inputs_dict) + self.assertEqual(results.mean_predictions.shape, 
(self.model_tester.batch_size, config.horizon_length)) + self.assertEqual( + results.full_predictions.shape, + (self.model_tester.batch_size, config.horizon_length, 1 + len(config.quantiles)), + ) + + def test_encoder_forward_matches_predict(self): + """The low-level `CtsmModel.forward` should accept the two-stream interface directly.""" + config = self.model_tester.get_config() + model = CtsmModel(config).to(torch_device).eval() + + coarse = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device) + fine = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device) + with torch.no_grad(): + out = model(past_values_coarse=coarse, past_values_fine=fine) + + coarse_patches = config.context_length // config.patch_length + fine_patches = config.context_length // config.patch_length + self.assertEqual( + out.last_hidden_state.shape, + (self.model_tester.batch_size, coarse_patches + 1 + fine_patches, config.hidden_size), + ) + self.assertEqual(out.loc.shape, (self.model_tester.batch_size,)) + self.assertEqual(out.loc_coarse.shape, (self.model_tester.batch_size,)) + + @unittest.skip(reason="CTSM uses a custom multi-resolution attention mask built internally.") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + def test_eager_matches_sdpa_inference( + self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + """CTSM builds its own mask from the concatenated stream paddings; the generic harness, which + injects external attention masks and mutates QK-norm RMSNorm eps, is not compatible. We verify + eager vs. 
SDPA equivalence on the low-level `CtsmModel` instead.""" + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest("Model does not support SDPA") + torch_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[dtype] + tolerance = {torch.float32: 1e-4, torch.bfloat16: 5e-3, torch.float16: 5e-3}[torch_dtype] + self._attn_kernel_equivalence("sdpa", dtype=torch_dtype, tolerance=tolerance) + + @unittest.skip(reason="Model does not have input embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="CTSM does not support gradient checkpointing in this version") + def test_gradient_checkpointing_backward_compatibility(self): + pass + + def _attn_kernel_equivalence(self, attn_implementation, dtype=torch.float32, tolerance=1e-4): + """Compare eager vs `attn_implementation` on the low-level `CtsmModel`. + + Uses the two-stream interface so we bypass the prediction-head AR loop which + adds numerical noise unrelated to the kernel choice. 
+ """ + config = self.model_tester.get_config() + model_eager = CtsmModel._from_config(config, attn_implementation="eager") + model_eager.to(dtype=dtype, device=torch_device).eval() + + model_other = CtsmModel._from_config(config, attn_implementation=attn_implementation) + model_other.load_state_dict(model_eager.state_dict()) + model_other.to(dtype=dtype, device=torch_device).eval() + + coarse = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device, dtype=dtype) + fine = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device, dtype=dtype) + + with torch.no_grad(): + out_e = model_eager(past_values_coarse=coarse, past_values_fine=fine) + out_o = model_other(past_values_coarse=coarse, past_values_fine=fine) + + diff = (out_e.last_hidden_state - out_o.last_hidden_state).abs().max().item() + self.assertLess(diff, tolerance, f"{attn_implementation} vs eager last_hidden_state max diff: {diff:.2e}") + + def test_eager_matches_sdpa(self): + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest("Model does not support SDPA") + self._attn_kernel_equivalence("sdpa", dtype=torch.float32, tolerance=1e-4) + + @require_flash_attn + @require_torch_accelerator + def test_flash_attn_2_inference_equivalence(self): + self._attn_kernel_equivalence("flash_attention_2", dtype=torch.bfloat16, tolerance=1e-2) + + def test_retain_grad_hidden_states_attentions(self): + """CTSM returns `mean_predictions` as the first tensor, not `last_hidden_state`.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = self.has_attentions + if self.has_attentions: + config._attn_implementation = "eager" + + model_class = self.all_model_classes[0] + model = model_class._from_config(config, attn_implementation="eager") + model.to(torch_device) + inputs = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**inputs) + + 
output_tensor = outputs.mean_predictions + if outputs.hidden_states is not None: + hidden_states = outputs.hidden_states[0] + hidden_states.retain_grad() + if self.has_attentions and outputs.attentions is not None: + attentions = outputs.attentions[0] + attentions.retain_grad() + + output_tensor.flatten()[0].backward(retain_graph=True) + + if outputs.hidden_states is not None: + self.assertIsNotNone(hidden_states.grad) + if self.has_attentions and outputs.attentions is not None: + self.assertIsNotNone(attentions.grad) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if return_labels: + batch_size = len(inputs_dict["past_values"]) + rng = random.Random(42) + inputs_dict["future_values"] = floats_tensor([batch_size, self.model_tester.horizon_length], rng=rng) + return inputs_dict + + def test_kv_cache_matches_full_recompute(self): + """Cached autoregressive decoding should produce close-to-identical predictions to the + full-recompute path (the small gap is from the stream-stats-freezing approximation).""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = CtsmModelForPrediction(config).to(torch_device).eval() + + # Long enough to trigger AR (horizon > config.horizon_length). + horizon_len = config.horizon_length * 3 + with torch.no_grad(): + out_full = model(**inputs_dict, horizon_len=horizon_len, use_cache=False) + out_cache = model(**inputs_dict, horizon_len=horizon_len, use_cache=True) + + # First horizon_length predictions must match bit-exactly (step 1 is identical in both paths). 
+ step1 = config.horizon_length + self.assertTrue( + torch.allclose(out_full.mean_predictions[:, :step1], out_cache.mean_predictions[:, :step1], atol=1e-5), + msg="Step-1 predictions must match bit-exactly between cached and non-cached paths.", + ) + # On subsequent AR steps the stats-freezing approximation introduces a small bounded drift. + # The bound is generous here because the tiny tester model has random weights and a horizon of 8, + # so compounding any small per-step shift over multiple steps is amplified. + relative = (out_full.mean_predictions - out_cache.mean_predictions).abs().max() / ( + out_full.mean_predictions.abs().max().clamp_min(1.0) + ) + self.assertLess(relative.item(), 0.5, f"cached vs full-recompute AR drift {relative.item():.2e} too large") + + +@require_torch +@slow +class CtsmModelIntegrationTests(unittest.TestCase): + def test_inference(self): + model = CtsmModelForPrediction.from_pretrained("cisco-ai/cisco-time-series-model-1.0").to(torch_device) + rng = np.random.default_rng(42) + series = (np.sin(np.linspace(0, 200, 512 * 60)) + 0.05 * rng.standard_normal(512 * 60)).astype(np.float32) + past_values = [torch.tensor(series, device=torch_device)] + + with torch.no_grad(): + output = model(past_values=past_values, horizon_len=128) + + self.assertEqual(output.mean_predictions.shape, (1, 128)) + self.assertEqual(output.full_predictions.shape, (1, 128, 1 + len(model.config.quantiles))) diff --git a/utils/check_repo.py b/utils/check_repo.py index c4b8e44b4dd8..2738bdc06540 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -251,6 +251,7 @@ "PPDocLayoutV3Model", # Building part of bigger (tested) model "TimesFmModel", # Building part of bigger (tested) model "TimesFm2_5Model", # Building part of bigger (tested) model + "CtsmModel", # Building part of bigger (tested) model "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. 
"CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
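To make the AR-loop bookkeeping in `CtsmModelForPrediction.forward` concrete: newly decoded fine values accumulate in a buffer until full `agg_factor`-sized blocks can be mean-aggregated into fresh coarse points. A minimal NumPy sketch of that single step (the helper name `flush_coarse_buffer` is hypothetical; the model code does the same with `torch.cat` / `view` / `mean`):

```python
import numpy as np


def flush_coarse_buffer(coarse, buffer, new_fine, agg):
    """Append freshly decoded fine values; fold any complete `agg`-sized
    blocks into new coarse points, keeping leftover samples buffered."""
    buffer = np.concatenate([buffer, new_fine])
    full_blocks = buffer.shape[0] // agg
    if full_blocks:
        blocks = buffer[: full_blocks * agg].reshape(full_blocks, agg).mean(axis=1)
        coarse = np.concatenate([coarse, blocks])
        buffer = buffer[full_blocks * agg :]
    return coarse, buffer


coarse = np.zeros(0, dtype=np.float32)
buffer = np.zeros(0, dtype=np.float32)
# Two AR steps, each emitting 6 new fine values, with agg = 4:
coarse, buffer = flush_coarse_buffer(coarse, buffer, np.arange(6, dtype=np.float32), agg=4)
# one full block -> coarse = [1.5], buffer keeps [4, 5]
coarse, buffer = flush_coarse_buffer(coarse, buffer, np.arange(6, 12, dtype=np.float32), agg=4)
# buffer grew to 8 samples -> two more blocks -> coarse = [1.5, 5.5, 9.5]
```

This is why the coarse block in the KV cache can stay frozen between incremental steps: new coarse points appear only every few AR iterations, at which point the prediction path falls back to a full forward.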