From 33b0c0d075ec6f7e57c4a54befb877259e685402 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 12:16:42 +0200 Subject: [PATCH 1/9] Add Cisco Time Series Model (CTSM) 1.0 Adds CTSM 1.0 (cisco-ai/cisco-time-series-model-1.0) as a first-class time-series foundation model. It is architecturally a TimesFM 2.0 decoder with multi-resolution inputs (coarse + learned special token + fine), rotary position embeddings, bidirectional attention over the coarse block, and 15-quantile prediction. - modular_ctsm.py reuses TimesFmAttention/DecoderLayer/Model and the TimesFm2_5 RoPE utilities so RoPE + per-dim Q scaling are shared. - CtsmModel.forward takes (past_values_coarse, past_values_fine) streams. CtsmModelForPrediction.forward takes a list of fine-res series and derives the coarse stream by mean-aggregation over agg_factor blocks, then runs an AR decode loop. - Registered in auto_mappings, MODEL_MAPPING, time-series-prediction mapping, models/__init__.py, _toctree.yml, and docs. - Tests mirror the timesfm2_5 pattern: full ModelTesterMixin coverage (with a custom eager-vs-SDPA equivalence that uses the native two-stream interface since CTSM builds its own mask). - Conversion script maps the fused qkv_proj + input/horizon residual blocks + multi_resolution / special_token / freq_emb to the transformers layout and has been verified end-to-end against the 250M Hub checkpoint. --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/ctsm.md | 81 ++ src/transformers/models/__init__.py | 1 + src/transformers/models/auto/auto_mappings.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/ctsm/__init__.py | 28 + .../models/ctsm/configuration_ctsm.py | 118 ++ .../ctsm/convert_ctsm_original_to_hf.py | 209 ++++ src/transformers/models/ctsm/modeling_ctsm.py | 1097 +++++++++++++++++ src/transformers/models/ctsm/modular_ctsm.py | 689 +++++++++++ tests/models/ctsm/__init__.py | 0 tests/models/ctsm/test_modeling_ctsm.py | 268 ++++ 12 files changed, 2496 insertions(+) create mode 100644 docs/source/en/model_doc/ctsm.md create mode 100644 src/transformers/models/ctsm/__init__.py create mode 100644 src/transformers/models/ctsm/configuration_ctsm.py create mode 100644 src/transformers/models/ctsm/convert_ctsm_original_to_hf.py create mode 100644 src/transformers/models/ctsm/modeling_ctsm.py create mode 100644 src/transformers/models/ctsm/modular_ctsm.py create mode 100644 tests/models/ctsm/__init__.py create mode 100644 tests/models/ctsm/test_modeling_ctsm.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a31944a3ef69..537a3cf198f0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1407,6 +1407,8 @@ - sections: - local: model_doc/autoformer title: Autoformer + - local: model_doc/ctsm + title: CTSM - local: model_doc/informer title: Informer - local: model_doc/patchtsmixer diff --git a/docs/source/en/model_doc/ctsm.md b/docs/source/en/model_doc/ctsm.md new file mode 100644 index 000000000000..372a038b839e --- /dev/null +++ b/docs/source/en/model_doc/ctsm.md @@ -0,0 +1,81 @@ + +*This model was released on 2025-11-25 and added to Hugging Face Transformers on 2026-04-17.* + +
+
+ PyTorch + FlashAttention + SDPA +
+
+ +# CTSM + +## Overview + +The Cisco Time Series Model (CTSM) 1.0 is a 250M-parameter decoder-only foundation model for univariate zero-shot +forecasting, proposed in [Cisco Time Series Model Technical Report](https://huggingface.co/papers/2511.19841) by +Liang Gou et al. It is architecturally inspired by [TimesFM 2.0](https://huggingface.co/google/timesfm-2.0-500m-pytorch) +and adds a multi-resolution context (a coarse stream aggregated by a configurable `agg_factor`, a learned special +token, and a fine stream), rotary position embeddings, bidirectional attention over the coarse-resolution block, +15-quantile prediction, and per-resolution learned embeddings. + +The checkpoint can be found at [`cisco-ai/cisco-time-series-model-1.0`](https://huggingface.co/cisco-ai/cisco-time-series-model-1.0). + +## Usage example + +```python +import numpy as np +import torch +from transformers import CtsmModelForPrediction + + +model = CtsmModelForPrediction.from_pretrained("cisco-ai/cisco-time-series-model-1.0", device_map="auto") + +# A fine-resolution (e.g. minute-level) time series. The coarse stream is built automatically +# by mean-aggregating consecutive blocks of `config.agg_factor` points. +series = np.sin(np.linspace(0, 200, 512 * 60)).astype(np.float32) +past_values = [torch.tensor(series, device=model.device)] + +with torch.no_grad(): + outputs = model(past_values=past_values, horizon_len=128) + +point_forecast = outputs.mean_predictions # (batch, horizon_len) +quantile_forecast = outputs.full_predictions # (batch, horizon_len, 1 + num_quantiles) +``` + +You can also pass `(coarse, fine)` pairs directly if you already have the coarse stream: + +```python +coarse = torch.tensor(coarse_series, dtype=torch.float32) +fine = torch.tensor(fine_series, dtype=torch.float32) +outputs = model(past_values=[(coarse, fine)], horizon_len=128) +``` + +## CtsmConfig + +[[autodoc]] CtsmConfig + +## CtsmModel + +[[autodoc]] CtsmModel + - forward + +## CtsmModelForPrediction + +[[autodoc]] CtsmModelForPrediction + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index acc5e2fdeac0..89e35fbd4e8c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -81,6 +81,7 @@ from .cpmant import * from .csm import * from .ctrl import * + from .ctsm import * from .cvt import * from .cwm import * from .d_fine import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 98c40e5a891b..efce4e02a8f2 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -108,6 +108,7 @@ ("csm", "CsmConfig"), ("csm_depth_decoder_model", "CsmDepthDecoderConfig"), ("ctrl", "CTRLConfig"), + ("ctsm", "CtsmConfig"), ("cvt", "CvtConfig"), ("cwm", "CwmConfig"), ("d_fine", "DFineConfig"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 44d83634d28e..d12f14b9615d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -99,6 +99,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("cpmant", "CpmAntModel"), ("csm", "CsmForConditionalGeneration"), ("ctrl", "CTRLModel"), + ("ctsm", "CtsmModel"), ("cvt", "CvtModel"), ("cwm", "CwmModel"), ("d_fine", "DFineModel"), @@ -1811,6 +1812,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict( [ + ("ctsm", "CtsmModelForPrediction"), ("timesfm", "TimesFmModelForPrediction"), ("timesfm2_5", "TimesFm2_5ModelForPrediction"), ] diff --git a/src/transformers/models/ctsm/__init__.py b/src/transformers/models/ctsm/__init__.py new file mode 100644 index 000000000000..e5979d7ba3db --- /dev/null +++ b/src/transformers/models/ctsm/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ctsm import * + from .modeling_ctsm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/ctsm/configuration_ctsm.py b/src/transformers/models/ctsm/configuration_ctsm.py new file mode 100644 index 000000000000..f2f62a97d0a7 --- /dev/null +++ b/src/transformers/models/ctsm/configuration_ctsm.py @@ -0,0 +1,118 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ctsm/modular_ctsm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ctsm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="cisco-ai/cisco-time-series-model-1.0") +@strict +class CtsmConfig(PreTrainedConfig): + r""" + patch_length (`int`, *optional*, defaults to 32): + Length of one patch in the input sequence for each resolution stream. + context_length (`int`, *optional*, defaults to 512): + Length of the input context for each resolution stream. + horizon_length (`int`, *optional*, defaults to 128): + Length of the prediction horizon produced per autoregressive step. + freq_size (`int`, *optional*, defaults to 3): + Number of frequency embeddings. + tolerance (`float`, *optional*, defaults to 1e-06): + Numerical tolerance used in normalization. + pad_val (`float`, *optional*, defaults to 1123581321.0): + Sentinel value marking padded positions in the input series. + num_hidden_layers (`int`, *optional*, defaults to 25): + Number of decoder layers. + quantiles (`list[float]`, *optional*, defaults to 15 values between 0.01 and 0.99): + Quantile levels predicted by the model. + use_positional_embedding (`bool`, *optional*, defaults to `False`): + CTSM uses rotary position embeddings and does not add sinusoidal positional embeddings. + use_resolution_embeddings (`bool`, *optional*, defaults to `True`): + Whether to add a learned embedding per resolution bucket (coarse / special / fine). + use_special_token (`bool`, *optional*, defaults to `True`): + Whether to insert a learned special token between the coarse and fine streams. + num_resolutions (`int`, *optional*, defaults to 3): + Number of resolution embeddings (coarse, special token, fine). + agg_factor (`int`, *optional*, defaults to 60): + Aggregation factor between fine and coarse resolutions (e.g. 60 minutes -> 1 hour). + max_position_embeddings (`int`, *optional*, defaults to 1025): + Maximum number of patches in the concatenated sequence (coarse + special + fine). + rope_parameters (`dict`, *optional*): + Rotary position embedding parameters. Defaults to `{"rope_type": "default", "rope_theta": 10000.0}`. + + Example: + + ```python + >>> from transformers import CtsmConfig, CtsmModelForPrediction + + >>> configuration = CtsmConfig() + >>> model = CtsmModelForPrediction(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "ctsm" + keys_to_ignore_at_inference = [] + is_encoder_decoder = False + + patch_length: int = 32 + context_length: int = 512 + horizon_length: int = 128 + freq_size: int = 3 + + num_hidden_layers: int = 25 + hidden_size: int = 1280 + intermediate_size: int = 1280 + head_dim: int = 80 + num_attention_heads: int = 16 + tolerance: float = 1e-6 + rms_norm_eps: float = 1e-6 + quantiles: list[float] | tuple[float, ...] = ( + 0.01, + 0.05, + 0.1, + 0.2, + 0.25, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.75, + 0.8, + 0.9, + 0.95, + 0.99, + ) + pad_val: float = 1123581321.0 + attention_dropout: float | int = 0.0 + use_positional_embedding: bool = False + initializer_range: float = 0.02 + use_resolution_embeddings: bool = True + use_special_token: bool = True + num_resolutions: int = 3 + agg_factor: int = 60 + max_position_embeddings: int = 1025 + rope_parameters: RopeParameters | dict | None = None + + +__all__ = ["CtsmConfig"] diff --git a/src/transformers/models/ctsm/convert_ctsm_original_to_hf.py b/src/transformers/models/ctsm/convert_ctsm_original_to_hf.py new file mode 100644 index 000000000000..0b618547b387 --- /dev/null +++ b/src/transformers/models/ctsm/convert_ctsm_original_to_hf.py @@ -0,0 +1,209 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert a Cisco Time Series Model (CTSM) 1.0 checkpoint to the transformers format. + +Sample usage: + +``` +python src/transformers/models/ctsm/convert_ctsm_original_to_hf.py \ + --output_dir /output/path \ + --huggingface_repo_id cisco-ai/cisco-time-series-model-1.0 +``` +""" + +import argparse +import os + +import torch +from huggingface_hub import snapshot_download + +from transformers import CtsmConfig, CtsmModelForPrediction + + +CTSM_CHECKPOINT_FILENAME = "torch_model.pt" + +# CTSM 1.0 public checkpoint ships 15 quantiles spanning [0.01, 0.99]. +CTSM_1_0_QUANTILES = [0.01, 0.05, 0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99] + + +def _layer_mapping(num_layers: int, hidden_size: int) -> dict[str, str | tuple[str, int]]: + """Return a mapping `old_key -> new_key` (or `(new_prefix, split_idx)` for fused QKV).""" + mapping: dict[str, str | tuple[str, int]] = { + # input tokenizer (residual block) + "input_ff_layer.hidden_layer.0.weight": "model.input_ff_layer.input_layer.weight", + "input_ff_layer.hidden_layer.0.bias": "model.input_ff_layer.input_layer.bias", + "input_ff_layer.output_layer.weight": "model.input_ff_layer.output_layer.weight", + "input_ff_layer.output_layer.bias": "model.input_ff_layer.output_layer.bias", + "input_ff_layer.residual_layer.weight": "model.input_ff_layer.residual_layer.weight", + "input_ff_layer.residual_layer.bias": "model.input_ff_layer.residual_layer.bias", + # frequency, resolution and special token embeddings + "freq_emb.weight": "model.freq_emb.weight", + "multi_resolution.weight": "model.multi_resolution.weight", + "special_token": "model.special_token", + # horizon head (residual block) + "horizon_ff_layer.hidden_layer.0.weight": "horizon_ff_layer.input_layer.weight", + "horizon_ff_layer.hidden_layer.0.bias": "horizon_ff_layer.input_layer.bias", + "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", + "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", + "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", + "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", + } + + layer_template = { + # fused qkv -> split into q, k, v below + "stacked_transformer.layers.{i}.self_attn.qkv_proj.weight": ("model.layers.{i}.self_attn", "qkv_weight"), + "stacked_transformer.layers.{i}.self_attn.qkv_proj.bias": ("model.layers.{i}.self_attn", "qkv_bias"), + "stacked_transformer.layers.{i}.self_attn.o_proj.weight": "model.layers.{i}.self_attn.o_proj.weight", + "stacked_transformer.layers.{i}.self_attn.o_proj.bias": "model.layers.{i}.self_attn.o_proj.bias", + "stacked_transformer.layers.{i}.self_attn.scaling": "model.layers.{i}.self_attn.scaling", + "stacked_transformer.layers.{i}.mlp.gate_proj.weight": "model.layers.{i}.mlp.gate_proj.weight", + "stacked_transformer.layers.{i}.mlp.gate_proj.bias": "model.layers.{i}.mlp.gate_proj.bias", + "stacked_transformer.layers.{i}.mlp.down_proj.weight": "model.layers.{i}.mlp.down_proj.weight", + "stacked_transformer.layers.{i}.mlp.down_proj.bias": "model.layers.{i}.mlp.down_proj.bias", + "stacked_transformer.layers.{i}.mlp.layer_norm.weight": "model.layers.{i}.mlp.layer_norm.weight", + "stacked_transformer.layers.{i}.mlp.layer_norm.bias": "model.layers.{i}.mlp.layer_norm.bias", + "stacked_transformer.layers.{i}.input_layernorm.weight": "model.layers.{i}.input_layernorm.weight", + } + for i in range(num_layers): + for old, new in layer_template.items(): + mapping[old.format(i=i)] = new.format(i=i) if isinstance(new, str) else (new[0].format(i=i), new[1]) + return mapping + + +def convert_state_dict(original_sd: dict[str, torch.Tensor], hidden_size: int) -> dict[str, torch.Tensor]: + """Rewrite the original CTSM state dict into the transformers key layout.""" + num_layers = 0 + for key in original_sd: + if key.startswith("stacked_transformer.layers."): + idx = int(key.split(".")[2]) + num_layers = max(num_layers, idx + 1) + if num_layers == 0: + raise ValueError("No transformer layers found in the original checkpoint.") + + mapping = _layer_mapping(num_layers, hidden_size) + new_sd: dict[str, torch.Tensor] = {} + missing: list[str] = [] + for old_key, target in mapping.items(): + if old_key not in original_sd: + missing.append(old_key) + continue + tensor = original_sd[old_key] + if isinstance(target, tuple): + prefix, kind = target + if kind == "qkv_weight": + q, k, v = tensor.split(hidden_size, dim=0) + new_sd[f"{prefix}.q_proj.weight"] = q.clone() + new_sd[f"{prefix}.k_proj.weight"] = k.clone() + new_sd[f"{prefix}.v_proj.weight"] = v.clone() + elif kind == "qkv_bias": + q, k, v = tensor.split(hidden_size, dim=0) + new_sd[f"{prefix}.q_proj.bias"] = q.clone() + new_sd[f"{prefix}.k_proj.bias"] = k.clone() + new_sd[f"{prefix}.v_proj.bias"] = v.clone() + else: + raise ValueError(f"Unknown fused projection kind: {kind}") + else: + new_sd[target] = tensor.clone() + if missing: + print(f"[warn] {len(missing)} expected key(s) missing from the original checkpoint (first 5): {missing[:5]}") + return new_sd + + +def _infer_config_from_state_dict(original_sd: dict[str, torch.Tensor]) -> CtsmConfig: + """Infer a `CtsmConfig` from an original CTSM 1.0 state dict.""" + num_layers = 1 + max( + (int(k.split(".")[2]) for k in original_sd if k.startswith("stacked_transformer.layers.")), + default=-1, + ) + hidden_size = original_sd["input_ff_layer.output_layer.weight"].shape[0] + qkv_out = original_sd["stacked_transformer.layers.0.self_attn.qkv_proj.weight"].shape[0] + # qkv is [3 * num_heads * head_dim, hidden_size] — split evenly. + num_heads = 16 + head_dim = qkv_out // (3 * num_heads) + horizon_out = original_sd["horizon_ff_layer.output_layer.weight"].shape[0] + horizon_length = 128 + num_outputs = horizon_out // horizon_length + quantiles = ( + CTSM_1_0_QUANTILES if num_outputs - 1 == len(CTSM_1_0_QUANTILES) else [0.1 * i for i in range(1, num_outputs)] + ) + + return CtsmConfig( + num_hidden_layers=num_layers, + hidden_size=hidden_size, + intermediate_size=hidden_size, + num_attention_heads=num_heads, + head_dim=head_dim, + patch_length=32, + context_length=512, + horizon_length=horizon_length, + quantiles=quantiles, + use_positional_embedding=False, + use_resolution_embeddings="multi_resolution.weight" in original_sd, + use_special_token="special_token" in original_sd, + agg_factor=60, + max_position_embeddings=1025, + ) + + +def write_model(output_dir: str, huggingface_repo_id: str, safe_serialization: bool = True) -> None: + os.makedirs(output_dir, exist_ok=True) + local_dir = snapshot_download(repo_id=huggingface_repo_id, allow_patterns=[CTSM_CHECKPOINT_FILENAME]) + checkpoint_path = os.path.join(local_dir, CTSM_CHECKPOINT_FILENAME) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"{CTSM_CHECKPOINT_FILENAME} not found in {huggingface_repo_id}") + + print(f"Loading original checkpoint from {checkpoint_path}") + original_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + config = _infer_config_from_state_dict(original_sd) + print( + f"Inferred CtsmConfig: layers={config.num_hidden_layers} hidden={config.hidden_size} " + f"heads={config.num_attention_heads} head_dim={config.head_dim} quantiles={len(config.quantiles)}" + ) + config.save_pretrained(output_dir) + + model = CtsmModelForPrediction(config) + converted_sd = convert_state_dict(original_sd, hidden_size=config.hidden_size) + + incompatible = model.load_state_dict(converted_sd, strict=False) + if incompatible.missing_keys: + print(f"[warn] missing keys after load: {incompatible.missing_keys[:10]}") + if incompatible.unexpected_keys: + print(f"[warn] unexpected keys after load: {incompatible.unexpected_keys[:10]}") + + model.save_pretrained(output_dir, safe_serialization=safe_serialization) + print(f"Saved transformers checkpoint to {output_dir}") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", required=True, help="Where to write the converted HF checkpoint.") + parser.add_argument( + "--huggingface_repo_id", + default="cisco-ai/cisco-time-series-model-1.0", + help="Original CTSM repo on the Hub.", + ) + parser.add_argument("--safe_serialization", type=bool, default=True) + args = parser.parse_args() + + write_model( + output_dir=args.output_dir, + huggingface_repo_id=args.huggingface_repo_id, + safe_serialization=args.safe_serialization, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py new file mode 100644 index 000000000000..411b6dca5c91 --- /dev/null +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -0,0 +1,1097 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ctsm/modular_ctsm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ctsm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ... import initialization as init +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...modeling_outputs import BaseModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_ctsm import CtsmConfig + + +@dataclass +@auto_docstring +class CtsmOutput(BaseModelOutput): + r""" + loc_coarse (`torch.Tensor` of shape `(batch_size,)`): + Per-stream mean used to normalize the coarse-resolution context. + scale_coarse (`torch.Tensor` of shape `(batch_size,)`): + Per-stream standard deviation used to normalize the coarse-resolution context. + num_coarse_patches (`int`): + Number of patches in the coarse-resolution block of the concatenated sequence. + num_fine_patches (`int`): + Number of patches in the fine-resolution block of the concatenated sequence. + """ + + loc: torch.Tensor | None = None + scale: torch.Tensor | None = None + + loc_coarse: torch.Tensor | None = None + scale_coarse: torch.Tensor | None = None + num_coarse_patches: int | None = None + num_fine_patches: int | None = None + + +@dataclass +@auto_docstring +class CtsmOutputForPrediction(BaseModelOutput): + r""" + mean_predictions (`torch.Tensor` of shape `(batch_size, horizon_length)`): + Point forecasts over the fine-resolution horizon. + full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`): + Concatenation of the mean prediction and the quantile predictions along the last axis. + """ + + mean_predictions: torch.Tensor | None = None + full_predictions: torch.Tensor | None = None + loss: torch.Tensor | float | None = None + + +class CtsmResidualBlock(nn.Module): + """Ctsm residual block.""" + + def __init__(self, input_dims, hidden_dims, output_dims): + super().__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + self.input_layer = nn.Linear(input_dims, hidden_dims) + self.activation = nn.SiLU() + self.output_layer = nn.Linear(hidden_dims, output_dims) + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.input_layer(x) + hidden = self.activation(hidden) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class CtsmRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: CtsmConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: CtsmConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def simple_eager_attention_forward( + module: nn.Module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float | int = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class CtsmAttention(nn.Module): + """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings.""" + + def __init__(self, config: CtsmConfig, layer_idx: int): + super().__init__() + self.config = config + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.layer_idx = layer_idx + + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_heads * self.head_dim + self.scaling = nn.Parameter(torch.empty((self.head_dim,))) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _scale_query(self, query: torch.Tensor) -> torch.Tensor: + scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim)) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + query_states = self._scale_query(query_states) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CtsmMLP(nn.Module): + """Pax MLP in pytorch.""" + + def __init__(self, config: CtsmConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +@use_kernel_forward_from_hub("RMSNorm") +class CtsmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + CtsmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class CtsmDecoderLayer(nn.Module): + """CTSM transformer block: attention with RoPE followed by TimesFM 2.0 MLP with padding masking.""" + + def __init__(self, config: CtsmConfig, layer_idx: int): + super().__init__() + self.self_attn = CtsmAttention(config, layer_idx=layer_idx) + self.mlp = CtsmMLP(config) + self.input_layernorm = CtsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + paddings: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, paddings=paddings) + return hidden_states + + +class CtsmPositionalEmbedding(nn.Module): + """Generates position embedding for a given 1-d sequence.""" + + def __init__(self, config: CtsmConfig): + super().__init__() + min_timescale = config.min_timescale + max_timescale = config.max_timescale + self.min_timescale, self.max_timescale = min_timescale, max_timescale + self.embedding_dims = config.hidden_size + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1) + self.register_buffer( + "inv_timescales", + min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment), + ) + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None and seq_length is None: + raise ValueError("Either position or seq_length must be provided") + + if position is None: + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0) + elif position.ndim != 2: + raise ValueError(f"position must be 2-dimensional, got shape {position.shape}") + + scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +@auto_docstring +class CtsmPreTrainedModel(PreTrainedModel): + config: CtsmConfig + base_model_prefix = "model" + _no_split_modules = ["CtsmDecoderLayer"] + main_input_name = "past_values" + input_modalities = ("time",) + _supports_sdpa = True + _can_record_outputs = { + "hidden_states": CtsmDecoderLayer, + "attentions": CtsmAttention, + } + _supports_flash_attn = True + _supports_flex_attn = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, CtsmAttention): + # Initialize scaling parameter + init.ones_(module.scaling) + elif isinstance(module, CtsmPositionalEmbedding): + num_timescales = module.embedding_dims // 2 + max_timescale, min_timescale = module.max_timescale, module.min_timescale + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max( + num_timescales - 1, 1 + ) + init.copy_( + module.inv_timescales, + min_timescale + * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment), + ) + if isinstance(module, CtsmModel) and getattr(module, "special_token", None) is not None: + init.normal_(module.special_token, mean=0.0, std=self.config.initializer_range) + + +def _convert_paddings_to_attention_bias(paddings: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """Convert a `[B, N]` padding mask (1.0 = padded) to a `[B, 1, 1, N]` additive bias.""" + min_value = torch.finfo(dtype).min + return (paddings.to(dtype) * min_value).view(paddings.shape[0], 1, 1, paddings.shape[1]) + + +@auto_docstring +class CtsmModel(CtsmPreTrainedModel): + r""" + The multi-resolution CTSM encoder. The forward pass consumes two aligned streams (a coarse low-frequency + context and a fine high-frequency context), concatenates them along the sequence dimension with an + optional learned special token, and runs a stack of rotary-attention transformer layers. Attention is + bidirectional within the coarse block and causal elsewhere. + """ + + def __init__(self, config: CtsmConfig): + super().__init__(config) + + self.config = config + self.input_ff_layer = CtsmResidualBlock( + input_dims=2 * config.patch_length, + output_dims=config.hidden_size, + hidden_dims=config.intermediate_size, + ) + self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size) + self.layers = nn.ModuleList( + [CtsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + if self.config.use_positional_embedding: + self.position_emb = CtsmPositionalEmbedding(config=config) + + if hasattr(self, "position_emb"): + del self.position_emb + + self.rotary_emb = CtsmRotaryEmbedding(config) + + if config.use_resolution_embeddings: + self.multi_resolution = nn.Embedding(config.num_resolutions, config.hidden_size) + + if config.use_special_token: + self.special_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + # Initialize weights and apply final processing + self.post_init() + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = self._ctsm_masked_mean_std(inputs, patched_pads) + sigma = torch.clamp(sigma, min=self.config.tolerance) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device), + outputs, + ) + return outputs, (mu, sigma) + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.LongTensor | None = None, + past_values_fine_padding: torch.LongTensor | None = None, + freq: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + r""" + past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`): + Coarse-resolution context (e.g. hourly aggregates). Length must be a multiple of `patch_length` or + will be left-padded to one. + past_values_fine (`torch.FloatTensor` of shape `(batch_size, fine_length)`): + Fine-resolution context (e.g. minute-level). Length must be a multiple of `patch_length` or will be + left-padded to one. + past_values_coarse_padding (`torch.LongTensor`, *optional*): + Padding mask for the coarse stream, `1.0` for padded positions and `0.0` for real values. + past_values_fine_padding (`torch.LongTensor`, *optional*): + Padding mask for the fine stream. + freq (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Frequency indices. Defaults to all zeros. + """ + if past_values_coarse_padding is None: + past_values_coarse_padding = torch.zeros_like(past_values_coarse) + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_coarse_padding = past_values_coarse_padding.to(past_values_coarse.dtype) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + patch_length = self.config.patch_length + past_values_coarse, past_values_coarse_padding = self._left_pad_to_patch_boundary( + past_values_coarse, past_values_coarse_padding, patch_length + ) + past_values_fine, past_values_fine_padding = self._left_pad_to_patch_boundary( + past_values_fine, past_values_fine_padding, patch_length + ) + + coarse_embeddings, coarse_patch_padding, stats_coarse = self._patchify_and_normalize( + past_values_coarse, past_values_coarse_padding + ) + fine_embeddings, fine_patch_padding, stats_fine = self._patchify_and_normalize( + past_values_fine, past_values_fine_padding + ) + + bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape + num_fine_patches = fine_embeddings.shape[1] + device = coarse_embeddings.device + dtype = coarse_embeddings.dtype + + if self.config.use_special_token: + special = self.special_token.to(device=device, dtype=dtype).expand(bsize, 1, hidden_size) + special_padding = torch.zeros(bsize, 1, device=device, dtype=coarse_patch_padding.dtype) + model_input = torch.cat([coarse_embeddings, special, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, special_padding, fine_patch_padding], dim=1) + num_special = 1 + else: + model_input = torch.cat([coarse_embeddings, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, fine_patch_padding], dim=1) + num_special = 0 + + if self.config.use_resolution_embeddings: + mr_coarse = torch.zeros(num_coarse_patches, dtype=torch.long, device=device) + mr_special = torch.full((num_special,), 1, dtype=torch.long, device=device) + mr_fine = torch.full((num_fine_patches,), 2, dtype=torch.long, device=device) + mr_idx = torch.cat([mr_coarse, mr_special, mr_fine], dim=0).unsqueeze(0).expand(bsize, -1) + model_input = model_input + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + model_input = model_input + self.freq_emb(freq) + + attention_mask = self._build_attention_mask(patch_padding, num_coarse_patches, model_input.dtype) + position_ids = ( + torch.arange(model_input.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand(bsize, -1) + ) + position_embeddings = self.rotary_emb(model_input, position_ids) + + hidden_states = model_input + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=patch_padding, + position_embeddings=position_embeddings, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=stats_fine[0], + scale=stats_fine[1], + loc_coarse=stats_coarse[0], + scale_coarse=stats_coarse[1], + num_coarse_patches=num_coarse_patches + num_special, # fine block starts here + num_fine_patches=num_fine_patches, + ) + + @staticmethod + def _prepare_4d_attention_mask( + attention_mask: torch.Tensor | None, + sequence_length: int, + dtype: torch.dtype, + device: torch.device, + is_causal: bool = True, + ) -> torch.Tensor | None: + """ + Creates 4D attention mask and combines causal and padding masks if needed. + + Args: + attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask + sequence_length: Length of the sequence + dtype: Data type of the mask + device: Device of the mask + is_causal: Whether to apply causal masking + + Returns: + 4D attention mask of shape (batch_size, 1, seq_length, seq_length) + """ + # Get minimum value for the dtype + min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min + + # Handle padding mask + if attention_mask is not None: + # Convert 2D padding mask to 4D attention mask + attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1) + attention_mask = attention_mask * min_value + + # Create causal mask if needed + if is_causal: + causal_mask = torch.triu( + torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value, + diagonal=1, + ) + causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length) + + # Combine with padding mask if it exists + if attention_mask is not None: + attention_mask = torch.minimum(attention_mask, causal_mask) + else: + attention_mask = causal_mask + + return attention_mask + + @staticmethod + def _ctsm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded values. + """ + + # Selecting the first patch with more than 3 unpadded values. + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + pad_sum = torch.sum(1 - padding, dim=2) + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.clamp(num_valid_elements, min=1.0) + + # Calculate the masked sum and mean + masked_sum = torch.sum(arr * mask, dim=1) + masked_mean = masked_sum / num_valid_elements # [b] + + # Calculate the masked variance using centered values + masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask + masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements + masked_var = torch.clamp(masked_var, min=0.0) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + @staticmethod + def _ctsm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + The shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + @staticmethod + def _left_pad_to_patch_boundary( + values: torch.Tensor, paddings: torch.Tensor, patch_length: int + ) -> tuple[torch.Tensor, torch.Tensor]: + rem = values.shape[1] % patch_length + if rem == 0: + return values, paddings + pad_len = patch_length - rem + values_pad = torch.zeros((values.shape[0], pad_len), device=values.device, dtype=values.dtype) + paddings_pad = torch.ones((paddings.shape[0], pad_len), device=paddings.device, dtype=paddings.dtype) + return torch.cat([values_pad, values], dim=1), torch.cat([paddings_pad, paddings], dim=1) + + def _patchify_and_normalize( + self, past_values: torch.Tensor, past_values_padding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + bsize = past_values.shape[0] + patched_inputs = past_values.view(bsize, -1, self.config.patch_length) + patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + embeddings = self.input_ff_layer(concat_inputs) + patch_padding = torch.min(patched_pads, dim=-1)[0] + return embeddings, patch_padding, stats + + def _build_attention_mask( + self, + patch_padding: torch.Tensor, + num_coarse_patches: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """Causal mask with bidirectional attention over the coarse-resolution block.""" + bsize, seq_len = patch_padding.shape + device = patch_padding.device + min_value = torch.finfo(dtype).min + + causal = torch.triu( + torch.ones((seq_len, seq_len), dtype=dtype, device=device) * min_value, + diagonal=1, + ) + if num_coarse_patches > 0: + causal[:num_coarse_patches, :num_coarse_patches] = 0.0 + causal = causal.view(1, 1, seq_len, seq_len) + + padding_bias = _convert_paddings_to_attention_bias(patch_padding, dtype) + return torch.minimum(causal, padding_bias) + + +class CtsmModelForPrediction(CtsmPreTrainedModel): + """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding.""" + + def __init__(self, config: CtsmConfig): + super().__init__(config) + + self.config = config + self.context_len = config.context_length + self.horizon_len = config.horizon_length + + self.model = CtsmModel(config) + num_outputs = 1 + len(config.quantiles) + self.horizon_ff_layer = CtsmResidualBlock( + input_dims=config.hidden_size, + output_dims=config.horizon_length * num_outputs, + hidden_dims=config.intermediate_size, + ) + + # Initialize weights and apply final processing + self.post_init() + + def _preprocess( + self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + ) -> tuple[torch.Tensor, ...]: + """Pad/truncate input time series to `context_len` and build a padding mask. + + Args: + inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. + freq: Optional list of frequencies (returned as a tensor when provided). + context_len: Optional context length override (defaults to `self.context_len`). + + Returns: + Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. + """ + if context_len is None: + context_len = self.context_len + + input_ts, input_padding = [], [] + + for ts in inputs: + input_len = ts.shape[0] + padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) + if input_len < context_len: + num_front_pad = context_len - input_len + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) + elif input_len > context_len: + ts = ts[-context_len:] + padding = padding[-(context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + if freq is not None: + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + return result + + def _postprocess_output( + self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1) + + mu, sigma = stats + return output_ts * sigma[:, None, None, None] + mu[:, None, None, None] + + def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + losses = [] + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[..., i] + loss = torch.max((q - 1) * errors, q * errors) + losses.append(loss.mean()) + return torch.stack(losses).mean() + + @can_return_tuple + @auto_docstring + def forward( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + future_values: torch.Tensor | None = None, + horizon_len: int | None = None, + freq: Sequence[int] | torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutputForPrediction: + r""" + past_values (`Sequence[torch.Tensor]`): + Either a list of 1-D fine-resolution tensors (the coarse stream is derived by mean-aggregating over + `agg_factor` consecutive points) or a list of `(coarse, fine)` pairs if both streams are provided. + future_values (`torch.Tensor`, *optional*): + Optional fine-resolution ground truth used to compute the loss. + horizon_len (`int`, *optional*): + Number of fine-resolution steps to forecast. Defaults to `config.horizon_length`. Values larger than + `config.horizon_length` trigger autoregressive decoding. + freq (`Sequence[int]` or `torch.Tensor`, *optional*): + Frequency indices. Defaults to zeros. + """ + device = self.horizon_ff_layer.input_layer.weight.device + horizon_len = horizon_len or self.config.horizon_length + if horizon_len <= 0: + raise ValueError("horizon_len must be positive") + + output_patch_len = self.config.horizon_length + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + + coarse, coarse_pad, fine, fine_pad = self._prepare_context(past_values, device=device) + bsize = coarse.shape[0] + + if freq is None: + freq_tensor = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq_tensor = torch.as_tensor( + list(freq) if not isinstance(freq, torch.Tensor) else freq, dtype=torch.long, device=device + ).view(bsize, 1) + + mean_chunks: list[torch.Tensor] = [] + quant_chunks: list[torch.Tensor] = [] + remaining = horizon_len + coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) + last_outputs: CtsmOutput | None = None + max_coarse = self.config.context_length + max_fine = self.config.context_length + agg = self.config.agg_factor + + for _ in range(num_decode_patches): + mean_patch, quant_patch, last_outputs = self._decode_step( + past_values_coarse=coarse, + past_values_fine=fine, + past_values_coarse_padding=coarse_pad, + past_values_fine_padding=fine_pad, + freq=freq_tensor, + **kwargs, + ) + take = min(remaining, output_patch_len) + mean_chunks.append(mean_patch[:, :take]) + quant_chunks.append(quant_patch[:, :take, :]) + remaining -= take + if remaining <= 0: + break + + # Append fine predictions to fine context. + fine = torch.cat([fine, mean_patch[:, :output_patch_len]], dim=1) + fine_pad = torch.cat( + [fine_pad, torch.zeros((bsize, output_patch_len), device=device, dtype=fine_pad.dtype)], dim=1 + ) + if fine.shape[1] > max_fine: + fine = fine[:, -max_fine:] + fine_pad = fine_pad[:, -max_fine:] + + # Aggregate into coarse context when enough fine samples accumulated. + coarse_buffer = torch.cat([coarse_buffer, mean_patch[:, :output_patch_len]], dim=1) + full_blocks = coarse_buffer.shape[1] // agg + if full_blocks > 0: + blocks = coarse_buffer[:, : full_blocks * agg].view(bsize, full_blocks, agg).mean(dim=2) + coarse_buffer = coarse_buffer[:, full_blocks * agg :] + coarse = torch.cat([coarse, blocks], dim=1) + coarse_pad = torch.cat( + [coarse_pad, torch.zeros((bsize, full_blocks), device=device, dtype=coarse_pad.dtype)], dim=1 + ) + if coarse.shape[1] > max_coarse: + coarse = coarse[:, -max_coarse:] + coarse_pad = coarse_pad[:, -max_coarse:] + + mean_predictions = torch.cat(mean_chunks, dim=1)[:, :horizon_len] + full_predictions = torch.cat( + [torch.cat(mean_chunks, dim=1)[:, :horizon_len, None], torch.cat(quant_chunks, dim=1)[:, :horizon_len, :]], + dim=-1, + ) + + loss = None + if future_values is not None: + target_len = min(future_values.shape[1], mean_predictions.shape[1]) + mse_loss = F.mse_loss(mean_predictions[:, :target_len], future_values[:, :target_len]) + quantile_loss = self._quantile_loss(full_predictions[:, :target_len, 1:], future_values[:, :target_len]) + loss = mse_loss + quantile_loss + + return CtsmOutputForPrediction( + last_hidden_state=last_outputs.last_hidden_state if last_outputs is not None else None, + hidden_states=last_outputs.hidden_states if last_outputs is not None else None, + attentions=last_outputs.attentions if last_outputs is not None else None, + mean_predictions=mean_predictions, + full_predictions=full_predictions, + loss=loss, + ) + + @staticmethod + def _ctsm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + """Calculates the moving average using PyTorch's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + # Create a convolution kernel + kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + # Apply convolution to calculate the moving average + smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() + return [smoothed_arr, arr - smoothed_arr] + + @staticmethod + def _build_multi_resolution( + series: torch.Tensor, agg_factor: int, coarse_len: int, fine_len: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build (coarse, fine) contexts from a 1-D fine-resolution series. + + Coarse is the mean of the last `coarse_len * agg_factor` fine samples, aligned to block boundaries. + Fine is the last `fine_len` samples. + """ + series = series.to(torch.float32).reshape(-1) + needed = coarse_len * agg_factor + raw = series[-needed:] + remainder = raw.shape[0] % agg_factor + if remainder: + raw = raw[remainder:] + if raw.numel() == 0: + coarse = series.new_empty((0,), dtype=torch.float32) + else: + coarse = raw.reshape(-1, agg_factor).mean(dim=1) + if coarse.shape[0] > coarse_len: + coarse = coarse[-coarse_len:] + fine = series[-fine_len:].to(torch.float32) + return coarse, fine + + def _prepare_context( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + coarse_len = self.config.context_length + fine_len = self.config.context_length + agg = self.config.agg_factor + + coarse_batch = torch.zeros((len(past_values), coarse_len), dtype=torch.float32, device=device) + coarse_pad = torch.zeros_like(coarse_batch) + fine_batch = torch.zeros((len(past_values), fine_len), dtype=torch.float32, device=device) + fine_pad = torch.zeros_like(fine_batch) + + for i, item in enumerate(past_values): + if isinstance(item, (tuple, list)) and len(item) == 2: + coarse, fine = item + coarse = torch.as_tensor(coarse, dtype=torch.float32, device=device).reshape(-1) + fine = torch.as_tensor(fine, dtype=torch.float32, device=device).reshape(-1) + else: + series = torch.as_tensor(item, dtype=torch.float32, device=device).reshape(-1) + coarse, fine = self._build_multi_resolution(series, agg, coarse_len, fine_len) + + c_n = coarse.shape[0] + if c_n >= coarse_len: + coarse_batch[i] = coarse[-coarse_len:] + elif c_n > 0: + coarse_batch[i, coarse_len - c_n :] = coarse + coarse_pad[i, : coarse_len - c_n] = 1.0 + else: + coarse_pad[i] = 1.0 + + f_n = fine.shape[0] + if f_n >= fine_len: + fine_batch[i] = fine[-fine_len:] + elif f_n > 0: + fine_batch[i, fine_len - f_n :] = fine + fine_pad[i, : fine_len - f_n] = 1.0 + else: + fine_pad[i] = 1.0 + + return coarse_batch, coarse_pad, fine_batch, fine_pad + + def _decode_step( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.Tensor, + past_values_fine_padding: torch.Tensor, + freq: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """One AR step: return (mean_patch, quantile_patch, model_outputs) at fine resolution. + + mean_patch: `[B, horizon_length]`, quantile_patch: `[B, horizon_length, num_quantiles]`, both denormalized. + """ + outputs: CtsmOutput = self.model( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + **kwargs, + ) + head = self.horizon_ff_layer(outputs.last_hidden_state) + bsize, total_patches, _ = head.shape + num_outputs = 1 + len(self.config.quantiles) + head = head.view(bsize, total_patches, self.config.horizon_length, num_outputs) + + # Last fine patch index in the concatenated sequence. + fine_last_idx = total_patches - 1 + fine_patch = head[:, fine_last_idx, :, :] + + loc = outputs.loc[:, None, None] + scale = outputs.scale[:, None, None] + mean_patch = fine_patch[..., 0] * scale[..., 0] + loc[..., 0] + quant_patch = fine_patch[..., 1:] * scale + loc + mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) + quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + return mean_patch, quant_patch, outputs + + +__all__ = ["CtsmModel", "CtsmModelForPrediction", "CtsmPreTrainedModel"] diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py new file mode 100644 index 000000000000..6bdf4465ce26 --- /dev/null +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -0,0 +1,689 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Cisco Time Series Model (CTSM).""" + +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub.dataclasses import strict + +from ... import initialization as init +from ...modeling_rope_utils import RopeParameters +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward +from ..timesfm.configuration_timesfm import TimesFmConfig +from ..timesfm.modeling_timesfm import ( + TimesFmAttention, + TimesFmDecoderLayer, + TimesFmModel, + TimesFmModelForPrediction, + TimesFmOutput, + TimesFmOutputForPrediction, + TimesFmPreTrainedModel, + TimesFmResidualBlock, # re-exported as CtsmResidualBlock in the generated file +) +from ..timesfm2_5.modeling_timesfm2_5 import ( + TimesFm2_5RotaryEmbedding, + apply_rotary_pos_emb, +) + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="cisco-ai/cisco-time-series-model-1.0") +@strict +class CtsmConfig(TimesFmConfig): + r""" + patch_length (`int`, *optional*, defaults to 32): + Length of one patch in the input sequence for each resolution stream. + context_length (`int`, *optional*, defaults to 512): + Length of the input context for each resolution stream. + horizon_length (`int`, *optional*, defaults to 128): + Length of the prediction horizon produced per autoregressive step. + freq_size (`int`, *optional*, defaults to 3): + Number of frequency embeddings. + tolerance (`float`, *optional*, defaults to 1e-06): + Numerical tolerance used in normalization. + pad_val (`float`, *optional*, defaults to 1123581321.0): + Sentinel value marking padded positions in the input series. + num_hidden_layers (`int`, *optional*, defaults to 25): + Number of decoder layers. + quantiles (`list[float]`, *optional*, defaults to 15 values between 0.01 and 0.99): + Quantile levels predicted by the model. + use_positional_embedding (`bool`, *optional*, defaults to `False`): + CTSM uses rotary position embeddings and does not add sinusoidal positional embeddings. + use_resolution_embeddings (`bool`, *optional*, defaults to `True`): + Whether to add a learned embedding per resolution bucket (coarse / special / fine). + use_special_token (`bool`, *optional*, defaults to `True`): + Whether to insert a learned special token between the coarse and fine streams. + num_resolutions (`int`, *optional*, defaults to 3): + Number of resolution embeddings (coarse, special token, fine). + agg_factor (`int`, *optional*, defaults to 60): + Aggregation factor between fine and coarse resolutions (e.g. 60 minutes -> 1 hour). + max_position_embeddings (`int`, *optional*, defaults to 1025): + Maximum number of patches in the concatenated sequence (coarse + special + fine). + rope_parameters (`dict`, *optional*): + Rotary position embedding parameters. Defaults to `{"rope_type": "default", "rope_theta": 10000.0}`. + + Example: + + ```python + >>> from transformers import CtsmConfig, CtsmModelForPrediction + + >>> configuration = CtsmConfig() + >>> model = CtsmModelForPrediction(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "ctsm" + + num_hidden_layers: int = 25 + context_length: int = 512 + quantiles: list[float] | tuple[float, ...] = ( + 0.01, + 0.05, + 0.1, + 0.2, + 0.25, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.75, + 0.8, + 0.9, + 0.95, + 0.99, + ) + use_positional_embedding: bool = False + use_resolution_embeddings: bool = True + use_special_token: bool = True + num_resolutions: int = 3 + agg_factor: int = 60 + max_position_embeddings: int = 1025 + rope_parameters: RopeParameters | dict | None = None + + min_timescale = AttributeError() + max_timescale = AttributeError() + + +@dataclass +@auto_docstring +class CtsmOutput(TimesFmOutput): + r""" + loc_coarse (`torch.Tensor` of shape `(batch_size,)`): + Per-stream mean used to normalize the coarse-resolution context. + scale_coarse (`torch.Tensor` of shape `(batch_size,)`): + Per-stream standard deviation used to normalize the coarse-resolution context. + num_coarse_patches (`int`): + Number of patches in the coarse-resolution block of the concatenated sequence. + num_fine_patches (`int`): + Number of patches in the fine-resolution block of the concatenated sequence. + """ + + loc_coarse: torch.Tensor | None = None + scale_coarse: torch.Tensor | None = None + num_coarse_patches: int | None = None + num_fine_patches: int | None = None + + +@dataclass +@auto_docstring +class CtsmOutputForPrediction(TimesFmOutputForPrediction): + r""" + mean_predictions (`torch.Tensor` of shape `(batch_size, horizon_length)`): + Point forecasts over the fine-resolution horizon. + full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`): + Concatenation of the mean prediction and the quantile predictions along the last axis. + """ + + pass + + +class CtsmResidualBlock(TimesFmResidualBlock): + pass + + +class CtsmRotaryEmbedding(TimesFm2_5RotaryEmbedding): + pass + + +class CtsmAttention(TimesFmAttention): + """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings.""" + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + query_states = self._scale_query(query_states) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CtsmDecoderLayer(TimesFmDecoderLayer): + """CTSM transformer block: attention with RoPE followed by TimesFM 2.0 MLP with padding masking.""" + + def __init__(self, config: CtsmConfig, layer_idx: int): + super().__init__(config, layer_idx=layer_idx) + self.self_attn = CtsmAttention(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + paddings: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, paddings=paddings) + return hidden_states + + +@auto_docstring +class CtsmPreTrainedModel(TimesFmPreTrainedModel): + config: CtsmConfig + base_model_prefix = "model" + _no_split_modules = ["CtsmDecoderLayer"] + _supports_flash_attn = True + _supports_flex_attn = True + _can_record_outputs = { + "hidden_states": CtsmDecoderLayer, + "attentions": CtsmAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, CtsmModel) and getattr(module, "special_token", None) is not None: + init.normal_(module.special_token, mean=0.0, std=self.config.initializer_range) + + +def _convert_paddings_to_attention_bias(paddings: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """Convert a `[B, N]` padding mask (1.0 = padded) to a `[B, 1, 1, N]` additive bias.""" + min_value = torch.finfo(dtype).min + return (paddings.to(dtype) * min_value).view(paddings.shape[0], 1, 1, paddings.shape[1]) + + +class CtsmModel(TimesFmModel): + r""" + The multi-resolution CTSM encoder. The forward pass consumes two aligned streams (a coarse low-frequency + context and a fine high-frequency context), concatenates them along the sequence dimension with an + optional learned special token, and runs a stack of rotary-attention transformer layers. Attention is + bidirectional within the coarse block and causal elsewhere. + """ + + def __init__(self, config: CtsmConfig): + super().__init__(config) + + if hasattr(self, "position_emb"): + del self.position_emb + + self.rotary_emb = CtsmRotaryEmbedding(config) + + if config.use_resolution_embeddings: + self.multi_resolution = nn.Embedding(config.num_resolutions, config.hidden_size) + + if config.use_special_token: + self.special_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + self.post_init() + + @staticmethod + def _left_pad_to_patch_boundary( + values: torch.Tensor, paddings: torch.Tensor, patch_length: int + ) -> tuple[torch.Tensor, torch.Tensor]: + rem = values.shape[1] % patch_length + if rem == 0: + return values, paddings + pad_len = patch_length - rem + values_pad = torch.zeros((values.shape[0], pad_len), device=values.device, dtype=values.dtype) + paddings_pad = torch.ones((paddings.shape[0], pad_len), device=paddings.device, dtype=paddings.dtype) + return torch.cat([values_pad, values], dim=1), torch.cat([paddings_pad, paddings], dim=1) + + def _patchify_and_normalize( + self, past_values: torch.Tensor, past_values_padding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + bsize = past_values.shape[0] + patched_inputs = past_values.view(bsize, -1, self.config.patch_length) + patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + embeddings = self.input_ff_layer(concat_inputs) + patch_padding = torch.min(patched_pads, dim=-1)[0] + return embeddings, patch_padding, stats + + def _build_attention_mask( + self, + patch_padding: torch.Tensor, + num_coarse_patches: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """Causal mask with bidirectional attention over the coarse-resolution block.""" + bsize, seq_len = patch_padding.shape + device = patch_padding.device + min_value = torch.finfo(dtype).min + + causal = torch.triu( + torch.ones((seq_len, seq_len), dtype=dtype, device=device) * min_value, + diagonal=1, + ) + if num_coarse_patches > 0: + causal[:num_coarse_patches, :num_coarse_patches] = 0.0 + causal = causal.view(1, 1, seq_len, seq_len) + + padding_bias = _convert_paddings_to_attention_bias(patch_padding, dtype) + return torch.minimum(causal, padding_bias) + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.LongTensor | None = None, + past_values_fine_padding: torch.LongTensor | None = None, + freq: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + r""" + past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`): + Coarse-resolution context (e.g. hourly aggregates). Length must be a multiple of `patch_length` or + will be left-padded to one. + past_values_fine (`torch.FloatTensor` of shape `(batch_size, fine_length)`): + Fine-resolution context (e.g. minute-level). Length must be a multiple of `patch_length` or will be + left-padded to one. + past_values_coarse_padding (`torch.LongTensor`, *optional*): + Padding mask for the coarse stream, `1.0` for padded positions and `0.0` for real values. + past_values_fine_padding (`torch.LongTensor`, *optional*): + Padding mask for the fine stream. + freq (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Frequency indices. Defaults to all zeros. + """ + if past_values_coarse_padding is None: + past_values_coarse_padding = torch.zeros_like(past_values_coarse) + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_coarse_padding = past_values_coarse_padding.to(past_values_coarse.dtype) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + patch_length = self.config.patch_length + past_values_coarse, past_values_coarse_padding = self._left_pad_to_patch_boundary( + past_values_coarse, past_values_coarse_padding, patch_length + ) + past_values_fine, past_values_fine_padding = self._left_pad_to_patch_boundary( + past_values_fine, past_values_fine_padding, patch_length + ) + + coarse_embeddings, coarse_patch_padding, stats_coarse = self._patchify_and_normalize( + past_values_coarse, past_values_coarse_padding + ) + fine_embeddings, fine_patch_padding, stats_fine = self._patchify_and_normalize( + past_values_fine, past_values_fine_padding + ) + + bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape + num_fine_patches = fine_embeddings.shape[1] + device = coarse_embeddings.device + dtype = coarse_embeddings.dtype + + if self.config.use_special_token: + special = self.special_token.to(device=device, dtype=dtype).expand(bsize, 1, hidden_size) + special_padding = torch.zeros(bsize, 1, device=device, dtype=coarse_patch_padding.dtype) + model_input = torch.cat([coarse_embeddings, special, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, special_padding, fine_patch_padding], dim=1) + num_special = 1 + else: + model_input = torch.cat([coarse_embeddings, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, fine_patch_padding], dim=1) + num_special = 0 + + if self.config.use_resolution_embeddings: + mr_coarse = torch.zeros(num_coarse_patches, dtype=torch.long, device=device) + mr_special = torch.full((num_special,), 1, dtype=torch.long, device=device) + mr_fine = torch.full((num_fine_patches,), 2, dtype=torch.long, device=device) + mr_idx = torch.cat([mr_coarse, mr_special, mr_fine], dim=0).unsqueeze(0).expand(bsize, -1) + model_input = model_input + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + model_input = model_input + self.freq_emb(freq) + + attention_mask = self._build_attention_mask(patch_padding, num_coarse_patches, model_input.dtype) + position_ids = ( + torch.arange(model_input.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand(bsize, -1) + ) + position_embeddings = self.rotary_emb(model_input, position_ids) + + hidden_states = model_input + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=patch_padding, + position_embeddings=position_embeddings, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=stats_fine[0], + scale=stats_fine[1], + loc_coarse=stats_coarse[0], + scale_coarse=stats_coarse[1], + num_coarse_patches=num_coarse_patches + num_special, # fine block starts here + num_fine_patches=num_fine_patches, + ) + + +class CtsmModelForPrediction(TimesFmModelForPrediction): + """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding.""" + + def __init__(self, config: CtsmConfig): + super().__init__(config) + del self.decoder + del self.horizon_ff_layer + + self.model = CtsmModel(config) + num_outputs = 1 + len(config.quantiles) + self.horizon_ff_layer = CtsmResidualBlock( + input_dims=config.hidden_size, + output_dims=config.horizon_length * num_outputs, + hidden_dims=config.intermediate_size, + ) + self.post_init() + + @staticmethod + def _build_multi_resolution( + series: torch.Tensor, agg_factor: int, coarse_len: int, fine_len: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build (coarse, fine) contexts from a 1-D fine-resolution series. + + Coarse is the mean of the last `coarse_len * agg_factor` fine samples, aligned to block boundaries. + Fine is the last `fine_len` samples. + """ + series = series.to(torch.float32).reshape(-1) + needed = coarse_len * agg_factor + raw = series[-needed:] + remainder = raw.shape[0] % agg_factor + if remainder: + raw = raw[remainder:] + if raw.numel() == 0: + coarse = series.new_empty((0,), dtype=torch.float32) + else: + coarse = raw.reshape(-1, agg_factor).mean(dim=1) + if coarse.shape[0] > coarse_len: + coarse = coarse[-coarse_len:] + fine = series[-fine_len:].to(torch.float32) + return coarse, fine + + def _prepare_context( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + coarse_len = self.config.context_length + fine_len = self.config.context_length + agg = self.config.agg_factor + + coarse_batch = torch.zeros((len(past_values), coarse_len), dtype=torch.float32, device=device) + coarse_pad = torch.zeros_like(coarse_batch) + fine_batch = torch.zeros((len(past_values), fine_len), dtype=torch.float32, device=device) + fine_pad = torch.zeros_like(fine_batch) + + for i, item in enumerate(past_values): + if isinstance(item, (tuple, list)) and len(item) == 2: + coarse, fine = item + coarse = torch.as_tensor(coarse, dtype=torch.float32, device=device).reshape(-1) + fine = torch.as_tensor(fine, dtype=torch.float32, device=device).reshape(-1) + else: + series = torch.as_tensor(item, dtype=torch.float32, device=device).reshape(-1) + coarse, fine = self._build_multi_resolution(series, agg, coarse_len, fine_len) + + c_n = coarse.shape[0] + if c_n >= coarse_len: + coarse_batch[i] = coarse[-coarse_len:] + elif c_n > 0: + coarse_batch[i, coarse_len - c_n :] = coarse + coarse_pad[i, : coarse_len - c_n] = 1.0 + else: + coarse_pad[i] = 1.0 + + f_n = fine.shape[0] + if f_n >= fine_len: + fine_batch[i] = fine[-fine_len:] + elif f_n > 0: + fine_batch[i, fine_len - f_n :] = fine + fine_pad[i, : fine_len - f_n] = 1.0 + else: + fine_pad[i] = 1.0 + + return coarse_batch, coarse_pad, fine_batch, fine_pad + + def _decode_step( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.Tensor, + past_values_fine_padding: torch.Tensor, + freq: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """One AR step: return (mean_patch, quantile_patch, model_outputs) at fine resolution. + + mean_patch: `[B, horizon_length]`, quantile_patch: `[B, horizon_length, num_quantiles]`, both denormalized. + """ + outputs: CtsmOutput = self.model( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + **kwargs, + ) + head = self.horizon_ff_layer(outputs.last_hidden_state) + bsize, total_patches, _ = head.shape + num_outputs = 1 + len(self.config.quantiles) + head = head.view(bsize, total_patches, self.config.horizon_length, num_outputs) + + # Last fine patch index in the concatenated sequence. + fine_last_idx = total_patches - 1 + fine_patch = head[:, fine_last_idx, :, :] + + loc = outputs.loc[:, None, None] + scale = outputs.scale[:, None, None] + mean_patch = fine_patch[..., 0] * scale[..., 0] + loc[..., 0] + quant_patch = fine_patch[..., 1:] * scale + loc + mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) + quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + return mean_patch, quant_patch, outputs + + @can_return_tuple + @auto_docstring + def forward( + self, + past_values: Sequence[torch.Tensor] | Sequence[tuple[torch.Tensor, torch.Tensor]], + future_values: torch.Tensor | None = None, + horizon_len: int | None = None, + freq: Sequence[int] | torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutputForPrediction: + r""" + past_values (`Sequence[torch.Tensor]`): + Either a list of 1-D fine-resolution tensors (the coarse stream is derived by mean-aggregating over + `agg_factor` consecutive points) or a list of `(coarse, fine)` pairs if both streams are provided. + future_values (`torch.Tensor`, *optional*): + Optional fine-resolution ground truth used to compute the loss. + horizon_len (`int`, *optional*): + Number of fine-resolution steps to forecast. Defaults to `config.horizon_length`. Values larger than + `config.horizon_length` trigger autoregressive decoding. + freq (`Sequence[int]` or `torch.Tensor`, *optional*): + Frequency indices. Defaults to zeros. + """ + device = self.horizon_ff_layer.input_layer.weight.device + horizon_len = horizon_len or self.config.horizon_length + if horizon_len <= 0: + raise ValueError("horizon_len must be positive") + + output_patch_len = self.config.horizon_length + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + + coarse, coarse_pad, fine, fine_pad = self._prepare_context(past_values, device=device) + bsize = coarse.shape[0] + + if freq is None: + freq_tensor = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq_tensor = torch.as_tensor( + list(freq) if not isinstance(freq, torch.Tensor) else freq, dtype=torch.long, device=device + ).view(bsize, 1) + + mean_chunks: list[torch.Tensor] = [] + quant_chunks: list[torch.Tensor] = [] + remaining = horizon_len + coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) + last_outputs: CtsmOutput | None = None + max_coarse = self.config.context_length + max_fine = self.config.context_length + agg = self.config.agg_factor + + for _ in range(num_decode_patches): + mean_patch, quant_patch, last_outputs = self._decode_step( + past_values_coarse=coarse, + past_values_fine=fine, + past_values_coarse_padding=coarse_pad, + past_values_fine_padding=fine_pad, + freq=freq_tensor, + **kwargs, + ) + take = min(remaining, output_patch_len) + mean_chunks.append(mean_patch[:, :take]) + quant_chunks.append(quant_patch[:, :take, :]) + remaining -= take + if remaining <= 0: + break + + # Append fine predictions to fine context. + fine = torch.cat([fine, mean_patch[:, :output_patch_len]], dim=1) + fine_pad = torch.cat( + [fine_pad, torch.zeros((bsize, output_patch_len), device=device, dtype=fine_pad.dtype)], dim=1 + ) + if fine.shape[1] > max_fine: + fine = fine[:, -max_fine:] + fine_pad = fine_pad[:, -max_fine:] + + # Aggregate into coarse context when enough fine samples accumulated. + coarse_buffer = torch.cat([coarse_buffer, mean_patch[:, :output_patch_len]], dim=1) + full_blocks = coarse_buffer.shape[1] // agg + if full_blocks > 0: + blocks = coarse_buffer[:, : full_blocks * agg].view(bsize, full_blocks, agg).mean(dim=2) + coarse_buffer = coarse_buffer[:, full_blocks * agg :] + coarse = torch.cat([coarse, blocks], dim=1) + coarse_pad = torch.cat( + [coarse_pad, torch.zeros((bsize, full_blocks), device=device, dtype=coarse_pad.dtype)], dim=1 + ) + if coarse.shape[1] > max_coarse: + coarse = coarse[:, -max_coarse:] + coarse_pad = coarse_pad[:, -max_coarse:] + + mean_predictions = torch.cat(mean_chunks, dim=1)[:, :horizon_len] + full_predictions = torch.cat( + [torch.cat(mean_chunks, dim=1)[:, :horizon_len, None], torch.cat(quant_chunks, dim=1)[:, :horizon_len, :]], + dim=-1, + ) + + loss = None + if future_values is not None: + target_len = min(future_values.shape[1], mean_predictions.shape[1]) + mse_loss = F.mse_loss(mean_predictions[:, :target_len], future_values[:, :target_len]) + quantile_loss = self._quantile_loss(full_predictions[:, :target_len, 1:], future_values[:, :target_len]) + loss = mse_loss + quantile_loss + + return CtsmOutputForPrediction( + last_hidden_state=last_outputs.last_hidden_state if last_outputs is not None else None, + hidden_states=last_outputs.hidden_states if last_outputs is not None else None, + attentions=last_outputs.attentions if last_outputs is not None else None, + mean_predictions=mean_predictions, + full_predictions=full_predictions, + loss=loss, + ) + + +__all__ = [ + "CtsmConfig", + "CtsmModel", + "CtsmModelForPrediction", + "CtsmPreTrainedModel", +] diff --git a/tests/models/ctsm/__init__.py b/tests/models/ctsm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/ctsm/test_modeling_ctsm.py b/tests/models/ctsm/test_modeling_ctsm.py new file mode 100644 index 000000000000..abda3ec19263 --- /dev/null +++ b/tests/models/ctsm/test_modeling_ctsm.py @@ -0,0 +1,268 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from transformers import CtsmConfig, is_torch_available +from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, floats_tensor + + +if is_torch_available(): + from transformers import CtsmModel, CtsmModelForPrediction + + +class CtsmModelTester: + def __init__( + self, + parent, + patch_length: int = 8, + context_length: int = 64, + horizon_length: int = 8, + num_hidden_layers: int = 2, + hidden_size: int = 32, + intermediate_size: int = 32, + head_dim: int = 16, + num_attention_heads: int = 2, + num_key_value_heads: int = 2, + rms_norm_eps: float = 1e-6, + quantiles=(0.1, 0.5, 0.9), + agg_factor: int = 4, + max_position_embeddings: int = 64, + batch_size: int = 2, + is_training: bool = True, + ): + self.parent = parent + self.patch_length = patch_length + self.context_length = context_length + self.horizon_length = horizon_length + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.quantiles = list(quantiles) + self.agg_factor = agg_factor + self.max_position_embeddings = max_position_embeddings + self.batch_size = batch_size + self.is_training = is_training + + # Total patches in the concatenated sequence (coarse + special + fine). + self.seq_length = 2 * (context_length // patch_length) + 1 + + def get_config(self): + return CtsmConfig( + patch_length=self.patch_length, + context_length=self.context_length, + horizon_length=self.horizon_length, + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + rms_norm_eps=self.rms_norm_eps, + quantiles=self.quantiles, + agg_factor=self.agg_factor, + max_position_embeddings=self.max_position_embeddings, + ) + + def get_pipeline_config(self): + return self.get_config() + + def prepare_config_and_inputs(self): + bsize = self.batch_size + past_values = [ + torch.tensor( + np.sin(np.linspace(0, 20, self.agg_factor * self.context_length)), + dtype=torch.float32, + device=torch_device, + ) + for _ in range(bsize) + ] + return self.get_config(), past_values + + def prepare_config_and_inputs_for_common(self): + config, past_values = self.prepare_config_and_inputs() + inputs_dict = {"past_values": past_values} + return config, inputs_dict + + +@require_torch +class CtsmModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (CtsmModelForPrediction,) if is_torch_available() else () + test_resize_embeddings = False + is_encoder_decoder = False + test_inputs_embeds = False + test_all_params_have_gradient = False + test_headmasking = False + test_pruning = False + test_missing_keys = False + test_model_parallel = False + + def setUp(self): + self.model_tester = CtsmModelTester(self) + self.config_tester = ConfigTester(self, config_class=CtsmConfig, has_text_modality=False) + + def test_create_and_run_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = CtsmModelForPrediction(config) + model.to(torch_device) + model.eval() + results = model(**inputs_dict) + self.assertEqual(results.mean_predictions.shape, (self.model_tester.batch_size, config.horizon_length)) + self.assertEqual( + results.full_predictions.shape, + (self.model_tester.batch_size, config.horizon_length, 1 + len(config.quantiles)), + ) + + def test_encoder_forward_matches_predict(self): + """The low-level `CtsmModel.forward` should accept the two-stream interface directly.""" + config = self.model_tester.get_config() + model = CtsmModel(config).to(torch_device).eval() + + coarse = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device) + fine = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device) + with torch.no_grad(): + out = model(past_values_coarse=coarse, past_values_fine=fine) + + coarse_patches = config.context_length // config.patch_length + fine_patches = config.context_length // config.patch_length + self.assertEqual( + out.last_hidden_state.shape, + (self.model_tester.batch_size, coarse_patches + 1 + fine_patches, config.hidden_size), + ) + self.assertEqual(out.loc.shape, (self.model_tester.batch_size,)) + self.assertEqual(out.loc_coarse.shape, (self.model_tester.batch_size,)) + + @unittest.skip(reason="CTSM uses a custom multi-resolution attention mask built internally.") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + def test_eager_matches_sdpa_inference( + self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + """CTSM builds its own mask from the concatenated stream paddings; the generic harness, which + injects external attention masks and mutates QK-norm RMSNorm eps, is not compatible. We verify + eager vs. SDPA equivalence on the low-level `CtsmModel` instead.""" + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest("Model does not support SDPA") + torch_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[dtype] + tolerance = {torch.float32: 1e-4, torch.bfloat16: 5e-3, torch.float16: 5e-3}[torch_dtype] + self._attn_kernel_equivalence("sdpa", dtype=torch_dtype, tolerance=tolerance) + + @unittest.skip(reason="Model does not have input embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="CTSM does not support gradient checkpointing in this version") + def test_gradient_checkpointing_backward_compatibility(self): + pass + + def _attn_kernel_equivalence(self, attn_implementation, dtype=torch.float32, tolerance=1e-4): + """Compare eager vs `attn_implementation` on the low-level `CtsmModel`. + + Uses the two-stream interface so we bypass the prediction-head AR loop which + adds numerical noise unrelated to the kernel choice. + """ + config = self.model_tester.get_config() + model_eager = CtsmModel._from_config(config, attn_implementation="eager") + model_eager.to(dtype=dtype, device=torch_device).eval() + + model_other = CtsmModel._from_config(config, attn_implementation=attn_implementation) + model_other.load_state_dict(model_eager.state_dict()) + model_other.to(dtype=dtype, device=torch_device).eval() + + coarse = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device, dtype=dtype) + fine = torch.randn(self.model_tester.batch_size, config.context_length, device=torch_device, dtype=dtype) + + with torch.no_grad(): + out_e = model_eager(past_values_coarse=coarse, past_values_fine=fine) + out_o = model_other(past_values_coarse=coarse, past_values_fine=fine) + + diff = (out_e.last_hidden_state - out_o.last_hidden_state).abs().max().item() + self.assertLess(diff, tolerance, f"{attn_implementation} vs eager last_hidden_state max diff: {diff:.2e}") + + def test_eager_matches_sdpa(self): + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest("Model does not support SDPA") + self._attn_kernel_equivalence("sdpa", dtype=torch.float32, tolerance=1e-4) + + @require_flash_attn + @require_torch_accelerator + def test_flash_attn_2_inference_equivalence(self): + self._attn_kernel_equivalence("flash_attention_2", dtype=torch.bfloat16, tolerance=1e-2) + + def test_retain_grad_hidden_states_attentions(self): + """CTSM returns `mean_predictions` as the first tensor, not `last_hidden_state`.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = self.has_attentions + if self.has_attentions: + config._attn_implementation = "eager" + + model_class = self.all_model_classes[0] + model = model_class._from_config(config, attn_implementation="eager") + model.to(torch_device) + inputs = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**inputs) + + output_tensor = outputs.mean_predictions + if outputs.hidden_states is not None: + hidden_states = outputs.hidden_states[0] + hidden_states.retain_grad() + if self.has_attentions and outputs.attentions is not None: + attentions = outputs.attentions[0] + attentions.retain_grad() + + output_tensor.flatten()[0].backward(retain_graph=True) + + if outputs.hidden_states is not None: + self.assertIsNotNone(hidden_states.grad) + if self.has_attentions and outputs.attentions is not None: + self.assertIsNotNone(attentions.grad) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if return_labels: + batch_size = len(inputs_dict["past_values"]) + rng = random.Random(42) + inputs_dict["future_values"] = floats_tensor([batch_size, self.model_tester.horizon_length], rng=rng) + return inputs_dict + + +@require_torch +@slow +class CtsmModelIntegrationTests(unittest.TestCase): + def test_inference(self): + model = CtsmModelForPrediction.from_pretrained("cisco-ai/cisco-time-series-model-1.0").to(torch_device) + rng = np.random.default_rng(42) + series = (np.sin(np.linspace(0, 200, 512 * 60)) + 0.05 * rng.standard_normal(512 * 60)).astype(np.float32) + past_values = [torch.tensor(series, device=torch_device)] + + with torch.no_grad(): + output = model(past_values=past_values, horizon_len=128) + + self.assertEqual(output.mean_predictions.shape, (1, 128)) + self.assertEqual(output.full_predictions.shape, (1, 128, 1 + len(model.config.quantiles))) From 6a79764fe0b96136a82db8c1e0ffc3269b77cf5e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 13:04:36 +0200 Subject: [PATCH 2/9] ctsm: use stream-level normalization (match official CTSM reference) The original CTSM reference normalizes each stream over the full non-padded context before the forward, then denormalizes the final prediction with the same stream stats. Inheriting TimesFM's first-patch normalization gives the same result mathematically (per-patch norm + denorm + stream norm + denorm is an identity over the extra factors), but sends inputs to the transformer in a different scale than what the checkpoint was trained on, and is less efficient. This replaces the per-first-patch `_forward_transform` step with a single stream-level `_normalize_with_pad` (matching `PatchedTSMultiResolutionDecoder` in the reference), returns stream stats as `CtsmOutput.loc/scale`, and lets `CtsmModelForPrediction._decode_step` denormalize in a single pass. Verified against the 250M hub checkpoint on the reference notebook datasets: cpu_util MAE model=2.11 naive_last=3.36 (~37% better) server_responsetime MAE model=0.65 naive_last=2.05 (~3x better) internet_traffic MAE model=805 naive_last=4071 (~5x better) Quantile predictions stay monotone; 95 tests still pass. --- src/transformers/models/ctsm/modeling_ctsm.py | 72 +++++++++++------- src/transformers/models/ctsm/modular_ctsm.py | 74 ++++++++++++------- 2 files changed, 96 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py index 411b6dca5c91..fd1af0d4e777 100644 --- a/src/transformers/models/ctsm/modeling_ctsm.py +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -43,12 +43,16 @@ @auto_docstring class CtsmOutput(BaseModelOutput): r""" + loc (`torch.Tensor` of shape `(batch_size,)`): + Stream-level mean used to normalize the fine-resolution context, reused to rescale the final forecast. + scale (`torch.Tensor` of shape `(batch_size,)`): + Stream-level standard deviation of the fine-resolution context. loc_coarse (`torch.Tensor` of shape `(batch_size,)`): - Per-stream mean used to normalize the coarse-resolution context. + Stream-level mean used to normalize the coarse-resolution context. scale_coarse (`torch.Tensor` of shape `(batch_size,)`): - Per-stream standard deviation used to normalize the coarse-resolution context. + Stream-level standard deviation of the coarse-resolution context. num_coarse_patches (`int`): - Number of patches in the coarse-resolution block of the concatenated sequence. + Number of patches (including the optional special token) preceding the fine-resolution block. num_fine_patches (`int`): Number of patches in the fine-resolution block of the concatenated sequence. """ @@ -540,13 +544,16 @@ def forward( past_values_fine, past_values_fine_padding, patch_length ) - coarse_embeddings, coarse_patch_padding, stats_coarse = self._patchify_and_normalize( - past_values_coarse, past_values_coarse_padding + coarse_normalized, loc_coarse, scale_coarse = self._normalize_with_pad( + past_values_coarse, past_values_coarse_padding, tolerance=self.config.tolerance ) - fine_embeddings, fine_patch_padding, stats_fine = self._patchify_and_normalize( - past_values_fine, past_values_fine_padding + fine_normalized, loc_fine, scale_fine = self._normalize_with_pad( + past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance ) + coarse_embeddings, coarse_patch_padding = self._patchify(coarse_normalized, past_values_coarse_padding) + fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape num_fine_patches = fine_embeddings.shape[1] device = coarse_embeddings.device @@ -594,10 +601,10 @@ def forward( return CtsmOutput( last_hidden_state=hidden_states, - loc=stats_fine[0], - scale=stats_fine[1], - loc_coarse=stats_coarse[0], - scale_coarse=stats_coarse[1], + loc=loc_fine, + scale=scale_fine, + loc_coarse=loc_coarse, + scale_coarse=scale_coarse, num_coarse_patches=num_coarse_patches + num_special, # fine block starts here num_fine_patches=num_fine_patches, ) @@ -739,29 +746,44 @@ def _left_pad_to_patch_boundary( paddings_pad = torch.ones((paddings.shape[0], pad_len), device=paddings.device, dtype=paddings.dtype) return torch.cat([values_pad, values], dim=1), torch.cat([paddings_pad, paddings], dim=1) - def _patchify_and_normalize( + @staticmethod + def _normalize_with_pad( + context: torch.Tensor, padding: torch.Tensor, tolerance: float = 1e-8 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stream-level normalization that matches the original CTSM reference. + + Normalizes ``context`` using the mean and standard deviation computed over the + non-padded positions (``padding == 0``) across the whole context, rather than + TimesFM's per-first-patch statistics. The normalized tensor has padded positions + zeroed out and is clamped to a safe range. + """ + valid = 1.0 - padding + count = valid.sum(dim=1, keepdim=True).clamp_min(1.0) + mu = (context * valid).sum(dim=1, keepdim=True) / count + + seq_len_f = context.new_tensor(float(context.shape[1])) + filled = torch.where(padding.to(dtype=torch.bool), mu, context) + sigma = filled.std(dim=1, keepdim=True, unbiased=False) * torch.sqrt(seq_len_f / count) + sigma = sigma.clamp_min(1e-2) + + normalized = (context - mu) / (sigma + tolerance) + normalized = normalized * valid + normalized = normalized.clamp(-1000.0, 1000.0) + return normalized, mu.squeeze(-1), sigma.squeeze(-1) + + def _patchify( self, past_values: torch.Tensor, past_values_padding: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor]: + """Patchify an already stream-normalized stream and project through the input tokenizer.""" bsize = past_values.shape[0] patched_inputs = past_values.view(bsize, -1, self.config.patch_length) patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) - patched_inputs = torch.where( - torch.abs(patched_pads - 1.0) < self.config.tolerance, - torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), - patched_inputs, - ) - patched_pads = torch.where( - torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), - patched_pads, - ) - patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) patched_inputs = patched_inputs * (1.0 - patched_pads) concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) embeddings = self.input_ff_layer(concat_inputs) patch_padding = torch.min(patched_pads, dim=-1)[0] - return embeddings, patch_padding, stats + return embeddings, patch_padding def _build_attention_mask( self, diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py index 6bdf4465ce26..ad8d63a86562 100644 --- a/src/transformers/models/ctsm/modular_ctsm.py +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -132,12 +132,16 @@ class CtsmConfig(TimesFmConfig): @auto_docstring class CtsmOutput(TimesFmOutput): r""" + loc (`torch.Tensor` of shape `(batch_size,)`): + Stream-level mean used to normalize the fine-resolution context, reused to rescale the final forecast. + scale (`torch.Tensor` of shape `(batch_size,)`): + Stream-level standard deviation of the fine-resolution context. loc_coarse (`torch.Tensor` of shape `(batch_size,)`): - Per-stream mean used to normalize the coarse-resolution context. + Stream-level mean used to normalize the coarse-resolution context. scale_coarse (`torch.Tensor` of shape `(batch_size,)`): - Per-stream standard deviation used to normalize the coarse-resolution context. + Stream-level standard deviation of the coarse-resolution context. num_coarse_patches (`int`): - Number of patches in the coarse-resolution block of the concatenated sequence. + Number of patches (including the optional special token) preceding the fine-resolution block. num_fine_patches (`int`): Number of patches in the fine-resolution block of the concatenated sequence. """ @@ -298,29 +302,44 @@ def _left_pad_to_patch_boundary( paddings_pad = torch.ones((paddings.shape[0], pad_len), device=paddings.device, dtype=paddings.dtype) return torch.cat([values_pad, values], dim=1), torch.cat([paddings_pad, paddings], dim=1) - def _patchify_and_normalize( + @staticmethod + def _normalize_with_pad( + context: torch.Tensor, padding: torch.Tensor, tolerance: float = 1e-8 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stream-level normalization that matches the original CTSM reference. + + Normalizes ``context`` using the mean and standard deviation computed over the + non-padded positions (``padding == 0``) across the whole context, rather than + TimesFM's per-first-patch statistics. The normalized tensor has padded positions + zeroed out and is clamped to a safe range. + """ + valid = 1.0 - padding + count = valid.sum(dim=1, keepdim=True).clamp_min(1.0) + mu = (context * valid).sum(dim=1, keepdim=True) / count + + seq_len_f = context.new_tensor(float(context.shape[1])) + filled = torch.where(padding.to(dtype=torch.bool), mu, context) + sigma = filled.std(dim=1, keepdim=True, unbiased=False) * torch.sqrt(seq_len_f / count) + sigma = sigma.clamp_min(1e-2) + + normalized = (context - mu) / (sigma + tolerance) + normalized = normalized * valid + normalized = normalized.clamp(-1000.0, 1000.0) + return normalized, mu.squeeze(-1), sigma.squeeze(-1) + + def _patchify( self, past_values: torch.Tensor, past_values_padding: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor]: + """Patchify an already stream-normalized stream and project through the input tokenizer.""" bsize = past_values.shape[0] patched_inputs = past_values.view(bsize, -1, self.config.patch_length) patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) - patched_inputs = torch.where( - torch.abs(patched_pads - 1.0) < self.config.tolerance, - torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), - patched_inputs, - ) - patched_pads = torch.where( - torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), - patched_pads, - ) - patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) patched_inputs = patched_inputs * (1.0 - patched_pads) concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) embeddings = self.input_ff_layer(concat_inputs) patch_padding = torch.min(patched_pads, dim=-1)[0] - return embeddings, patch_padding, stats + return embeddings, patch_padding def _build_attention_mask( self, @@ -385,12 +404,17 @@ def forward( past_values_fine, past_values_fine_padding, patch_length ) - coarse_embeddings, coarse_patch_padding, stats_coarse = self._patchify_and_normalize( - past_values_coarse, past_values_coarse_padding + coarse_normalized, loc_coarse, scale_coarse = self._normalize_with_pad( + past_values_coarse, past_values_coarse_padding, tolerance=self.config.tolerance ) - fine_embeddings, fine_patch_padding, stats_fine = self._patchify_and_normalize( - past_values_fine, past_values_fine_padding + fine_normalized, loc_fine, scale_fine = self._normalize_with_pad( + past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance + ) + + coarse_embeddings, coarse_patch_padding = self._patchify( + coarse_normalized, past_values_coarse_padding ) + fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape num_fine_patches = fine_embeddings.shape[1] @@ -439,10 +463,10 @@ def forward( return CtsmOutput( last_hidden_state=hidden_states, - loc=stats_fine[0], - scale=stats_fine[1], - loc_coarse=stats_coarse[0], - scale_coarse=stats_coarse[1], + loc=loc_fine, + scale=scale_fine, + loc_coarse=loc_coarse, + scale_coarse=scale_coarse, num_coarse_patches=num_coarse_patches + num_special, # fine block starts here num_fine_patches=num_fine_patches, ) From d2384be7991897e6e827ee753d91a293fd29a0f7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 13:28:29 +0200 Subject: [PATCH 3/9] ctsm: document why there is no KV cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each AR step recomputes the full forward by design: (1) coarse attention is bidirectional, so a new coarse patch invalidates every existing coarse K/V entry — the standard `DynamicCache.update(...)` append semantics can't express that; (2) stream normalization is recomputed per step over the raw context, which shifts every patch embedding. The original reference makes the same choice explicit (`CTSMAttentionRoPE` raises NotImplementedError on cache arguments), and it matches the convention of other time-series forecasters in transformers (TimesFM, TimesFM 2.5, PatchTST, Informer, Autoformer). --- src/transformers/models/ctsm/modeling_ctsm.py | 10 +++++++++- src/transformers/models/ctsm/modular_ctsm.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py index fd1af0d4e777..44bd5829e1cf 100644 --- a/src/transformers/models/ctsm/modeling_ctsm.py +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -809,7 +809,15 @@ def _build_attention_mask( class CtsmModelForPrediction(CtsmPreTrainedModel): - """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding.""" + """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding. + + Note: there is no KV cache. Each autoregressive step recomputes the full forward because (1) the + coarse-resolution block uses bidirectional attention, so appending a new coarse patch invalidates + every existing coarse K/V entry, and (2) stream-level normalization is recomputed every step after + new predictions are appended to the raw context, which shifts every patch embedding. This matches + the original CTSM reference (`CTSMAttentionRoPE` explicitly raises on cache arguments) and the + convention of other time-series forecasters in transformers (TimesFM, PatchTST, Informer, ...). + """ def __init__(self, config: CtsmConfig): super().__init__(config) diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py index ad8d63a86562..c989daea9372 100644 --- a/src/transformers/models/ctsm/modular_ctsm.py +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -473,7 +473,15 @@ def forward( class CtsmModelForPrediction(TimesFmModelForPrediction): - """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding.""" + """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding. + + Note: there is no KV cache. Each autoregressive step recomputes the full forward because (1) the + coarse-resolution block uses bidirectional attention, so appending a new coarse patch invalidates + every existing coarse K/V entry, and (2) stream-level normalization is recomputed every step after + new predictions are appended to the raw context, which shifts every patch embedding. This matches + the original CTSM reference (`CTSMAttentionRoPE` explicitly raises on cache arguments) and the + convention of other time-series forecasters in transformers (TimesFM, PatchTST, Informer, ...). + """ def __init__(self, config: CtsmConfig): super().__init__(config) From 247549f11272344ef326164598c0585cb8846dc3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 13:51:22 +0200 Subject: [PATCH 4/9] ctsm: flesh out model doc from the paper Rewrite the model doc to mirror the transformers model-doc template and pull content directly from the CTSM Technical Report (arXiv:2511.19841): - Full author list verified against the arXiv author list in order. - Quoted abstract. - Architecture section distinguishing the paper's 1.0-preview (500M, 50 layers, 9 quantiles, CPT from TimesFM 2.0) from the 1.0 release checkpoint actually on the Hub (250M, 25 layers, 15 quantiles, trained from scratch, adds RoPE, bidirectional coarse attention, short-context training). - Inference section noting the AR multi-resolution decode loop and why there is no KV cache. - Two usage snippets: auto-built coarse stream, and explicit (coarse, fine) pairs. - BibTeX citation using a BibTeX-safe form for the Yuhan Song entry (the parenthetical nickname in the paper parses oddly in BibTeX). --- docs/source/en/model_doc/ctsm.md | 71 +++++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 15 deletions(-) diff --git a/docs/source/en/model_doc/ctsm.md b/docs/source/en/model_doc/ctsm.md index 372a038b839e..f4053f7c42b8 100644 --- a/docs/source/en/model_doc/ctsm.md +++ b/docs/source/en/model_doc/ctsm.md @@ -27,16 +27,42 @@ rendered properly in your Markdown viewer. ## Overview -The Cisco Time Series Model (CTSM) 1.0 is a 250M-parameter decoder-only foundation model for univariate zero-shot -forecasting, proposed in [Cisco Time Series Model Technical Report](https://huggingface.co/papers/2511.19841) by -Liang Gou et al. It is architecturally inspired by [TimesFM 2.0](https://huggingface.co/google/timesfm-2.0-500m-pytorch) -and adds a multi-resolution context (a coarse stream aggregated by a configurable `agg_factor`, a learned special -token, and a fine stream), rotary position embeddings, bidirectional attention over the coarse-resolution block, -15-quantile prediction, and per-resolution learned embeddings. +The Cisco Time Series Model (CTSM) was proposed in [Cisco Time Series Model Technical Report](https://huggingface.co/papers/2511.19841) by Liang Gou, Archit Khare, Praneet Pabolu, Prachi Patel, Joseph Ross, Hercy Shen, Yuhan (Ellen) Song, Jingze Sun, Kristal Curtis, Vedant Dharnidharka, Abhinav Mathur and Hao Yang. -The checkpoint can be found at [`cisco-ai/cisco-time-series-model-1.0`](https://huggingface.co/cisco-ai/cisco-time-series-model-1.0). +CTSM is a decoder-only univariate zero-shot forecasting foundation model. Its central idea is a **multi-resolution context**: instead of consuming a single-scale history, each forecast conditions on two aligned streams — a coarse low-frequency stream (e.g. 512 hourly points) and a fine high-frequency stream (e.g. 512 minutely points), with the resolution ratio fixed to 60. A learnable **special token** separates the two streams and learned **resolution embeddings** are added to the token stream to distinguish them. The coarse stream lets the model see week-over-week structure without giving up fine-grained recent detail; as the paper puts it, "more complex multiresolution architectures would require a context length of 30,720 (30 times as long as ours) to cover the same time range." -## Usage example +The abstract from the paper is the following: + +*We introduce the Cisco Time Series Model, a univariate zero-shot forecaster. This time series foundation model is the result of a general architectural innovation to a time series model enabling it to accept multiresolution input, applied to a popular decoder-only time series model (TimesFM). The resulting multiresolution decoder-only model is trained on over 300B unique data points, with more than half coming from the observability domain. Quantitative and qualitative evaluations demonstrate that the resulting model achieves superior performance on observability datasets while retaining very similar performance on a standard general-purpose forecasting benchmark (GIFT-Eval), and suggest that the multiresolution structure enables the model to make more accurate predictions on long context input.* + +### Architecture + +The backbone follows TimesFM 2.0: patching (patch length 32) + a residual-block input tokenizer + decoder-only transformer layers with per-dimension learnable query scaling + a residual-block horizon head. CTSM adds, on top: + +- A **special token** inserted between the coarse and fine patch streams, so the input is `[coarse₁, …, coarse₁₆, SPECIAL, fine₁, …, fine₁₆]`. +- **Resolution embeddings** (3-way: coarse / special / fine) added to each token before the transformer stack. +- **Stream-level normalization**: each stream is standardized independently over its non-padded context, and the fine-stream statistics are used to rescale the forecast. +- A **frequency embedding** inherited from TimesFM, added to every token. + +The 250M **CTSM 1.0** release checkpoint additionally introduces (over the 500M `1.0-preview` described in the paper): + +- **Rotary position embeddings (RoPE)** applied to query/key inside attention. +- **Bidirectional attention over the coarse block** — tokens in the coarse segment attend both ways within that segment, while the fine segment remains causal. +- **15-quantile prediction** (levels 0.01–0.99) instead of 9. +- **Short-context training** (1/3 of training samples drawn with `|fine| ∈ [10, 511]`) for better robustness when less history is available. +- Trained from scratch (not continued pre-training from TimesFM 2.0) on ~2× more internal observability data. + +### Inference + +For horizons longer than `config.horizon_length` (128 steps), [`CtsmModelForPrediction`] runs an autoregressive multi-resolution decode loop: each step produces 128 fine-resolution predictions, the mean forecast is appended to the fine context, and every `agg_factor=60` new fine samples are mean-aggregated into a new coarse point. There is no KV cache — the coarse block's bidirectional attention and the per-step stream renormalization make the standard append-only cache unsuitable, matching both the original reference implementation and the other time-series forecasters in `transformers`. + +The checkpoint can be found at [`cisco-ai/cisco-time-series-model-1.0`](https://huggingface.co/cisco-ai/cisco-time-series-model-1.0). The original inference code is at [github.com/splunk/cisco-time-series-model](https://github.com/splunk/cisco-time-series-model). + +This model was contributed by [kashif](https://huggingface.co/kashif). + +## Usage + +Pass a list of fine-resolution time series (e.g. minute-level); the coarse stream is built automatically by mean-aggregating consecutive blocks of `config.agg_factor` points. ```python import numpy as np @@ -46,26 +72,27 @@ from transformers import CtsmModelForPrediction model = CtsmModelForPrediction.from_pretrained("cisco-ai/cisco-time-series-model-1.0", device_map="auto") -# A fine-resolution (e.g. minute-level) time series. The coarse stream is built automatically -# by mean-aggregating consecutive blocks of `config.agg_factor` points. +# ~8.5 hours of 1-minute data; the model will build a 512-hour coarse context by aggregation. series = np.sin(np.linspace(0, 200, 512 * 60)).astype(np.float32) past_values = [torch.tensor(series, device=model.device)] with torch.no_grad(): outputs = model(past_values=past_values, horizon_len=128) -point_forecast = outputs.mean_predictions # (batch, horizon_len) -quantile_forecast = outputs.full_predictions # (batch, horizon_len, 1 + num_quantiles) +point_forecast = outputs.mean_predictions # (batch, horizon_len) +quantile_forecast = outputs.full_predictions # (batch, horizon_len, 1 + num_quantiles) ``` -You can also pass `(coarse, fine)` pairs directly if you already have the coarse stream: +If you already have a coarse stream (e.g. pre-computed 1-hour roll-ups that go further back than you have 1-minute data for), pass `(coarse, fine)` pairs directly: ```python -coarse = torch.tensor(coarse_series, dtype=torch.float32) -fine = torch.tensor(fine_series, dtype=torch.float32) +coarse = torch.tensor(hourly_series, dtype=torch.float32) # up to 512 points +fine = torch.tensor(minutely_series, dtype=torch.float32) # up to 512 points outputs = model(past_values=[(coarse, fine)], horizon_len=128) ``` +For `horizon_len > 128`, the model decodes autoregressively and extends the output accordingly. + ## CtsmConfig [[autodoc]] CtsmConfig @@ -79,3 +106,17 @@ outputs = model(past_values=[(coarse, fine)], horizon_len=128) [[autodoc]] CtsmModelForPrediction - forward + +## Citation + +```bibtex +@misc{gou2025ciscotimeseriesmodel, + title={Cisco Time Series Model Technical Report}, + author={Liang Gou and Archit Khare and Praneet Pabolu and Prachi Patel and Joseph Ross and Hercy Shen and Yuhan Song and Jingze Sun and Kristal Curtis and Vedant Dharnidharka and Abhinav Mathur and Hao Yang}, + year={2025}, + eprint={2511.19841}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2511.19841} +} +``` From 289c089f7831cd3a7bcffdf136843ed66a97bf16 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 13:58:03 +0200 Subject: [PATCH 5/9] ctsm: delegate mask construction to TimesFmModel._prepare_4d_attention_mask CtsmModel inherits from TimesFmModel, which already provides a _prepare_4d_attention_mask(attention_mask, sequence_length, dtype, device, is_causal) static method combining padding + causal into a 4D additive mask. My _build_attention_mask was re-implementing the same logic (plus a one-line bidirectional-coarse zeroing), and _convert_paddings_to_attention_bias was duplicating the padding-to-bias conversion inside it. Replace both with a call to the inherited method + the single bidirectional patch. Numerically identical (cpu_util MAE 2.1093, same as before), 95 tests still pass. --- src/transformers/models/ctsm/modeling_ctsm.py | 28 ++++++------------- src/transformers/models/ctsm/modular_ctsm.py | 28 ++++++------------- 2 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py index 44bd5829e1cf..a25d3b47a100 100644 --- a/src/transformers/models/ctsm/modeling_ctsm.py +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -442,12 +442,6 @@ def _init_weights(self, module): init.normal_(module.special_token, mean=0.0, std=self.config.initializer_range) -def _convert_paddings_to_attention_bias(paddings: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - """Convert a `[B, N]` padding mask (1.0 = padded) to a `[B, 1, 1, N]` additive bias.""" - min_value = torch.finfo(dtype).min - return (paddings.to(dtype) * min_value).view(paddings.shape[0], 1, 1, paddings.shape[1]) - - @auto_docstring class CtsmModel(CtsmPreTrainedModel): r""" @@ -791,21 +785,17 @@ def _build_attention_mask( num_coarse_patches: int, dtype: torch.dtype, ) -> torch.Tensor: - """Causal mask with bidirectional attention over the coarse-resolution block.""" - bsize, seq_len = patch_padding.shape - device = patch_padding.device - min_value = torch.finfo(dtype).min - - causal = torch.triu( - torch.ones((seq_len, seq_len), dtype=dtype, device=device) * min_value, - diagonal=1, + """Reuse TimesFM's padding+causal 4D mask, then open the coarse-coarse block to bidirectional.""" + attention_mask = self._prepare_4d_attention_mask( + attention_mask=patch_padding, + sequence_length=patch_padding.shape[1], + dtype=dtype, + device=patch_padding.device, + is_causal=True, ) if num_coarse_patches > 0: - causal[:num_coarse_patches, :num_coarse_patches] = 0.0 - causal = causal.view(1, 1, seq_len, seq_len) - - padding_bias = _convert_paddings_to_attention_bias(patch_padding, dtype) - return torch.minimum(causal, padding_bias) + attention_mask[..., :num_coarse_patches, :num_coarse_patches] = 0.0 + return attention_mask class CtsmModelForPrediction(CtsmPreTrainedModel): diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py index c989daea9372..abbaa706245e 100644 --- a/src/transformers/models/ctsm/modular_ctsm.py +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -260,12 +260,6 @@ def _init_weights(self, module): init.normal_(module.special_token, mean=0.0, std=self.config.initializer_range) -def _convert_paddings_to_attention_bias(paddings: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - """Convert a `[B, N]` padding mask (1.0 = padded) to a `[B, 1, 1, N]` additive bias.""" - min_value = torch.finfo(dtype).min - return (paddings.to(dtype) * min_value).view(paddings.shape[0], 1, 1, paddings.shape[1]) - - class CtsmModel(TimesFmModel): r""" The multi-resolution CTSM encoder. The forward pass consumes two aligned streams (a coarse low-frequency @@ -347,21 +341,17 @@ def _build_attention_mask( num_coarse_patches: int, dtype: torch.dtype, ) -> torch.Tensor: - """Causal mask with bidirectional attention over the coarse-resolution block.""" - bsize, seq_len = patch_padding.shape - device = patch_padding.device - min_value = torch.finfo(dtype).min - - causal = torch.triu( - torch.ones((seq_len, seq_len), dtype=dtype, device=device) * min_value, - diagonal=1, + """Reuse TimesFM's padding+causal 4D mask, then open the coarse-coarse block to bidirectional.""" + attention_mask = self._prepare_4d_attention_mask( + attention_mask=patch_padding, + sequence_length=patch_padding.shape[1], + dtype=dtype, + device=patch_padding.device, + is_causal=True, ) if num_coarse_patches > 0: - causal[:num_coarse_patches, :num_coarse_patches] = 0.0 - causal = causal.view(1, 1, seq_len, seq_len) - - padding_bias = _convert_paddings_to_attention_bias(patch_padding, dtype) - return torch.minimum(causal, padding_bias) + attention_mask[..., :num_coarse_patches, :num_coarse_patches] = 0.0 + return attention_mask @merge_with_config_defaults @capture_outputs From 085e8445e8c0567da797853a1f99ce7d3cc57113 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 14:23:26 +0200 Subject: [PATCH 6/9] ctsm: document `loss` on CtsmOutputForPrediction CtsmOutputForPrediction inherits `loss` from TimesFmOutputForPrediction, but the @auto_docstring check requires every field of the dataclass to be documented in the class docstring. Add the missing `loss` entry and rerun the modular converter + ruff format so the generated file is in sync. --- src/transformers/models/ctsm/modeling_ctsm.py | 2 ++ src/transformers/models/ctsm/modular_ctsm.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py index a25d3b47a100..c98758a5e60f 100644 --- a/src/transformers/models/ctsm/modeling_ctsm.py +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -74,6 +74,8 @@ class CtsmOutputForPrediction(BaseModelOutput): Point forecasts over the fine-resolution horizon. full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`): Concatenation of the mean prediction and the quantile predictions along the last axis. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + Training loss combining MSE of the mean forecast and quantile loss when fine-resolution targets are supplied. """ mean_predictions: torch.Tensor | None = None diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py index abbaa706245e..bc6f7879ee8d 100644 --- a/src/transformers/models/ctsm/modular_ctsm.py +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -160,6 +160,8 @@ class CtsmOutputForPrediction(TimesFmOutputForPrediction): Point forecasts over the fine-resolution horizon. full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`): Concatenation of the mean prediction and the quantile predictions along the last axis. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + Training loss combining MSE of the mean forecast and quantile loss when fine-resolution targets are supplied. """ pass @@ -401,9 +403,7 @@ def forward( past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance ) - coarse_embeddings, coarse_patch_padding = self._patchify( - coarse_normalized, past_values_coarse_padding - ) + coarse_embeddings, coarse_patch_padding = self._patchify(coarse_normalized, past_values_coarse_padding) fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape From 1c34986abff6a8dfa7e23b7d9999f2daa7bd8ef2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Apr 2026 14:39:41 +0200 Subject: [PATCH 7/9] ctsm: add CtsmModel to IGNORE_NON_TESTED Mirrors TimesFmModel / TimesFm2_5Model: CtsmModel is the building block used by CtsmModelForPrediction, which is the only class in `all_model_classes` in the test file. Common tests exercise CtsmModel through the prediction wrapper; there is nothing to add to the test list. --- utils/check_repo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index c4b8e44b4dd8..2738bdc06540 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -251,6 +251,7 @@ "PPDocLayoutV3Model", # Building part of bigger (tested) model "TimesFmModel", # Building part of bigger (tested) model "TimesFm2_5Model", # Building part of bigger (tested) model + "CtsmModel", # Building part of bigger (tested) model "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. From f9298763a3f1eb13434d08f4c93124e95e3e9243 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 18 Apr 2026 10:30:54 +0200 Subject: [PATCH 8/9] ctsm: add KV cache for autoregressive decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For `horizon_len > config.horizon_length`, `CtsmModelForPrediction` now reuses a `DynamicCache` across autoregressive steps: - Step 1 runs a full forward over `[coarse, special, fine]` and populates the cache with K/V per layer. - Subsequent steps feed only the four new fine patches through the stack; their Q/K/V attend to `past_key_values.update(...)`-merged K/V. - Stream normalization stats are frozen to their step-1 values so cached embeddings stay on a consistent scale; the coarse block is pinned; if the cache would outgrow `max_position_embeddings` it's discarded and rebuilt from the current raw contexts. - `use_cache: bool | None` on `CtsmModelForPrediction.forward` lets callers force the old full-recompute path if they prefer. API additions mirror Llama et al.: - `CtsmAttention.forward(..., past_key_values=None)` - `CtsmDecoderLayer.forward(..., past_key_values=None)` - `CtsmModel.forward(..., past_key_values=None, use_cache=None, cache_position=None, loc_fine=None, scale_fine=None)` — when `past_key_values` is provided, `past_values_fine` must contain only the new fine values and `loc_fine` / `scale_fine` must be supplied so normalization matches the cached state. - `CtsmOutput.past_key_values` field. Benchmarks on the 250M hub checkpoint (CPU, horizon=512, cpu_utilization): use_cache=False 521 ms MAE=2.6852 use_cache=True 400 ms MAE=2.6852 MAE is bit-identical across the three notebook datasets. Added a `test_kv_cache_matches_full_recompute` regression test that verifies step-1 predictions are exact and subsequent AR steps stay within a generous bound on the tiny random-weights tester model. --- docs/source/en/model_doc/ctsm.md | 2 +- src/transformers/models/ctsm/modeling_ctsm.py | 476 +++++++++++++----- src/transformers/models/ctsm/modular_ctsm.py | 326 ++++++++++-- tests/models/ctsm/test_modeling_ctsm.py | 26 + 4 files changed, 658 insertions(+), 172 deletions(-) diff --git a/docs/source/en/model_doc/ctsm.md b/docs/source/en/model_doc/ctsm.md index f4053f7c42b8..8d891a07f633 100644 --- a/docs/source/en/model_doc/ctsm.md +++ b/docs/source/en/model_doc/ctsm.md @@ -54,7 +54,7 @@ The 250M **CTSM 1.0** release checkpoint additionally introduces (over the 500M ### Inference -For horizons longer than `config.horizon_length` (128 steps), [`CtsmModelForPrediction`] runs an autoregressive multi-resolution decode loop: each step produces 128 fine-resolution predictions, the mean forecast is appended to the fine context, and every `agg_factor=60` new fine samples are mean-aggregated into a new coarse point. There is no KV cache — the coarse block's bidirectional attention and the per-step stream renormalization make the standard append-only cache unsuitable, matching both the original reference implementation and the other time-series forecasters in `transformers`. +For `horizon_len > config.horizon_length`, [`CtsmModelForPrediction`] runs an autoregressive multi-resolution decode loop, using a [`DynamicCache`] by default (opt out with `use_cache=False`). Each step feeds only the newly-appended fine patches through the stack and attends to cached K/V for every earlier position. Stream-normalization statistics are frozen to their step-1 values so that cached K/V remains valid; the coarse block is pinned and the cache is rebuilt if the concatenated sequence would outgrow `max_position_embeddings`. The checkpoint can be found at [`cisco-ai/cisco-time-series-model-1.0`](https://huggingface.co/cisco-ai/cisco-time-series-model-1.0). The original inference code is at [github.com/splunk/cisco-time-series-model](https://github.com/splunk/cisco-time-series-model). diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py index c98758a5e60f..4d63554ff3a2 100644 --- a/src/transformers/models/ctsm/modeling_ctsm.py +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -28,6 +28,7 @@ import torch.nn.functional as F from ... import initialization as init +from ...cache_utils import Cache, DynamicCache from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...modeling_outputs import BaseModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -55,6 +56,10 @@ class CtsmOutput(BaseModelOutput): Number of patches (including the optional special token) preceding the fine-resolution block. num_fine_patches (`int`): Number of patches in the fine-resolution block of the concatenated sequence. + past_key_values (`Cache`, *optional*): + Key/value cache for the concatenated `[coarse, special, fine]` sequence. Populated when the + caller passes `use_cache=True` (and re-used across autoregressive decode steps). Typically only + the long-horizon AR loop in [`CtsmModelForPrediction`] needs this. """ loc: torch.Tensor | None = None @@ -64,6 +69,7 @@ class CtsmOutput(BaseModelOutput): scale_coarse: torch.Tensor | None = None num_coarse_patches: int | None = None num_fine_patches: int | None = None + past_key_values: Cache | None = None @dataclass @@ -226,7 +232,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): class CtsmAttention(nn.Module): - """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings.""" + """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings. + + Supports an optional `past_key_values` cache so that, during long-horizon autoregressive decoding, + each step only needs to compute K/V for the newly-appended fine patches and attends to the + previously-cached K/V for every earlier position. + """ def __init__(self, config: CtsmConfig, layer_idx: int): super().__init__() @@ -257,6 +268,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -271,6 +283,9 @@ def forward( query_states = self._scale_query(query_states) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, simple_eager_attention_forward ) @@ -348,6 +363,7 @@ def forward( attention_mask: torch.Tensor, paddings: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -356,6 +372,7 @@ def forward( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, + past_key_values=past_key_values, ) hidden_states = residual + hidden_states hidden_states = self.mlp(hidden_states, paddings=paddings) @@ -504,105 +521,68 @@ def _forward_transform( @auto_docstring def forward( self, - past_values_coarse: torch.Tensor, - past_values_fine: torch.Tensor, + past_values_coarse: torch.Tensor | None = None, + past_values_fine: torch.Tensor | None = None, past_values_coarse_padding: torch.LongTensor | None = None, past_values_fine_padding: torch.LongTensor | None = None, freq: torch.Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + loc_fine: torch.Tensor | None = None, + scale_fine: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> CtsmOutput: r""" - past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`): + past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`, *optional*): Coarse-resolution context (e.g. hourly aggregates). Length must be a multiple of `patch_length` or - will be left-padded to one. + will be left-padded to one. Required when `past_key_values` is `None`. past_values_fine (`torch.FloatTensor` of shape `(batch_size, fine_length)`): - Fine-resolution context (e.g. minute-level). Length must be a multiple of `patch_length` or will be - left-padded to one. + Fine-resolution context (e.g. minute-level). In the normal / full-forward mode this is the entire + fine context; when `past_key_values` is supplied this should contain **only the new fine values** + to append — they must already be pre-normalized by the caller using `loc_fine` / `scale_fine`. past_values_coarse_padding (`torch.LongTensor`, *optional*): Padding mask for the coarse stream, `1.0` for padded positions and `0.0` for real values. past_values_fine_padding (`torch.LongTensor`, *optional*): Padding mask for the fine stream. freq (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Frequency indices. Defaults to all zeros. + past_key_values (`Cache`, *optional*): + A [`Cache`] (typically a [`DynamicCache`]) holding K/V for the concatenated + `[coarse, special, fine_prefix]` sequence from a previous call. When supplied the model runs in + **incremental mode**: only the new fine patches are embedded, and their Q/K/V are added on top + of the cached K/V. `loc_fine` / `scale_fine` **must** also be supplied so the new fine values + are normalized on the same scale as the cached ones. + use_cache (`bool`, *optional*): + Whether to build and return a key/value cache in the `CtsmOutput`. Defaults to `False` unless + `past_key_values` is provided (in which case caching is always on). + cache_position (`torch.LongTensor` of shape `(num_new,)`, *optional*): + Absolute positions (in the full `[coarse, special, fine]` sequence) of the new fine patches. + Only used in incremental mode; defaults to `torch.arange(past_length, past_length + num_new)`. + loc_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream mean used for stream normalization. Required in incremental mode. + scale_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream standard deviation used for stream normalization. Required in incremental mode. """ - if past_values_coarse_padding is None: - past_values_coarse_padding = torch.zeros_like(past_values_coarse) - if past_values_fine_padding is None: - past_values_fine_padding = torch.zeros_like(past_values_fine) - past_values_coarse_padding = past_values_coarse_padding.to(past_values_coarse.dtype) - past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) - - patch_length = self.config.patch_length - past_values_coarse, past_values_coarse_padding = self._left_pad_to_patch_boundary( - past_values_coarse, past_values_coarse_padding, patch_length - ) - past_values_fine, past_values_fine_padding = self._left_pad_to_patch_boundary( - past_values_fine, past_values_fine_padding, patch_length - ) - - coarse_normalized, loc_coarse, scale_coarse = self._normalize_with_pad( - past_values_coarse, past_values_coarse_padding, tolerance=self.config.tolerance - ) - fine_normalized, loc_fine, scale_fine = self._normalize_with_pad( - past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance - ) - - coarse_embeddings, coarse_patch_padding = self._patchify(coarse_normalized, past_values_coarse_padding) - fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) - - bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape - num_fine_patches = fine_embeddings.shape[1] - device = coarse_embeddings.device - dtype = coarse_embeddings.dtype - - if self.config.use_special_token: - special = self.special_token.to(device=device, dtype=dtype).expand(bsize, 1, hidden_size) - special_padding = torch.zeros(bsize, 1, device=device, dtype=coarse_patch_padding.dtype) - model_input = torch.cat([coarse_embeddings, special, fine_embeddings], dim=1) - patch_padding = torch.cat([coarse_patch_padding, special_padding, fine_patch_padding], dim=1) - num_special = 1 - else: - model_input = torch.cat([coarse_embeddings, fine_embeddings], dim=1) - patch_padding = torch.cat([coarse_patch_padding, fine_patch_padding], dim=1) - num_special = 0 - - if self.config.use_resolution_embeddings: - mr_coarse = torch.zeros(num_coarse_patches, dtype=torch.long, device=device) - mr_special = torch.full((num_special,), 1, dtype=torch.long, device=device) - mr_fine = torch.full((num_fine_patches,), 2, dtype=torch.long, device=device) - mr_idx = torch.cat([mr_coarse, mr_special, mr_fine], dim=0).unsqueeze(0).expand(bsize, -1) - model_input = model_input + self.multi_resolution(mr_idx) - - if freq is None: - freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) - else: - freq = freq.to(device=device, dtype=torch.long) - model_input = model_input + self.freq_emb(freq) - - attention_mask = self._build_attention_mask(patch_padding, num_coarse_patches, model_input.dtype) - position_ids = ( - torch.arange(model_input.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand(bsize, -1) - ) - position_embeddings = self.rotary_emb(model_input, position_ids) - - hidden_states = model_input - for layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = layer( - hidden_states, - attention_mask=attention_mask, - paddings=patch_padding, - position_embeddings=position_embeddings, + if past_key_values is None: + return self._full_forward( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + use_cache=bool(use_cache), **kwargs, ) - - return CtsmOutput( - last_hidden_state=hidden_states, - loc=loc_fine, - scale=scale_fine, - loc_coarse=loc_coarse, - scale_coarse=scale_coarse, - num_coarse_patches=num_coarse_patches + num_special, # fine block starts here - num_fine_patches=num_fine_patches, + return self._incremental_forward( + past_values_fine=past_values_fine, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + past_key_values=past_key_values, + cache_position=cache_position, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, ) @staticmethod @@ -799,16 +779,201 @@ def _build_attention_mask( attention_mask[..., :num_coarse_patches, :num_coarse_patches] = 0.0 return attention_mask + def _build_incremental_attention_mask( + self, bsize: int, num_new: int, past_length: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Mask for the incremental (cached) path: new fine Qs attend to all cached K/V plus causal within the new block.""" + min_value = torch.finfo(dtype).min + mask = torch.zeros((bsize, 1, num_new, past_length + num_new), dtype=dtype, device=device) + if num_new > 1: + causal_new = torch.triu(torch.full((num_new, num_new), min_value, dtype=dtype, device=device), diagonal=1) + mask[:, :, :, past_length:] = causal_new + return mask + + def _full_forward( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.LongTensor | None, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + use_cache: bool, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if past_values_coarse_padding is None: + past_values_coarse_padding = torch.zeros_like(past_values_coarse) + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_coarse_padding = past_values_coarse_padding.to(past_values_coarse.dtype) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + patch_length = self.config.patch_length + past_values_coarse, past_values_coarse_padding = self._left_pad_to_patch_boundary( + past_values_coarse, past_values_coarse_padding, patch_length + ) + past_values_fine, past_values_fine_padding = self._left_pad_to_patch_boundary( + past_values_fine, past_values_fine_padding, patch_length + ) + + coarse_normalized, loc_coarse, scale_coarse = self._normalize_with_pad( + past_values_coarse, past_values_coarse_padding, tolerance=self.config.tolerance + ) + fine_normalized, loc_fine, scale_fine = self._normalize_with_pad( + past_values_fine, past_values_fine_padding, tolerance=self.config.tolerance + ) + + coarse_embeddings, coarse_patch_padding = self._patchify(coarse_normalized, past_values_coarse_padding) + fine_embeddings, fine_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + + bsize, num_coarse_patches, hidden_size = coarse_embeddings.shape + num_fine_patches = fine_embeddings.shape[1] + device = coarse_embeddings.device + dtype = coarse_embeddings.dtype + + if self.config.use_special_token: + special = self.special_token.to(device=device, dtype=dtype).expand(bsize, 1, hidden_size) + special_padding = torch.zeros(bsize, 1, device=device, dtype=coarse_patch_padding.dtype) + model_input = torch.cat([coarse_embeddings, special, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, special_padding, fine_patch_padding], dim=1) + num_special = 1 + else: + model_input = torch.cat([coarse_embeddings, fine_embeddings], dim=1) + patch_padding = torch.cat([coarse_patch_padding, fine_patch_padding], dim=1) + num_special = 0 + + if self.config.use_resolution_embeddings: + mr_coarse = torch.zeros(num_coarse_patches, dtype=torch.long, device=device) + mr_special = torch.full((num_special,), 1, dtype=torch.long, device=device) + mr_fine = torch.full((num_fine_patches,), 2, dtype=torch.long, device=device) + mr_idx = torch.cat([mr_coarse, mr_special, mr_fine], dim=0).unsqueeze(0).expand(bsize, -1) + model_input = model_input + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + model_input = model_input + self.freq_emb(freq) + + attention_mask = self._build_attention_mask(patch_padding, num_coarse_patches, model_input.dtype) + position_ids = ( + torch.arange(model_input.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand(bsize, -1) + ) + position_embeddings = self.rotary_emb(model_input, position_ids) + + past_key_values = DynamicCache() if use_cache else None + + hidden_states = model_input + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=patch_padding, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + loc_coarse=loc_coarse, + scale_coarse=scale_coarse, + num_coarse_patches=num_coarse_patches + num_special, + num_fine_patches=num_fine_patches, + past_key_values=past_key_values, + ) + + def _incremental_forward( + self, + past_values_fine: torch.Tensor, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + past_key_values: Cache, + cache_position: torch.LongTensor | None, + loc_fine: torch.Tensor | None, + scale_fine: torch.Tensor | None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if loc_fine is None or scale_fine is None: + raise ValueError( + "`loc_fine` and `scale_fine` must be supplied together with `past_key_values` so that the new fine " + "values are normalized on the same scale as the cached ones." + ) + if past_values_fine.shape[1] % self.config.patch_length != 0: + raise ValueError( + f"In incremental mode `past_values_fine` length must be a multiple of `patch_length=" + f"{self.config.patch_length}`; got {past_values_fine.shape[1]}." + ) + + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + tol = self.config.tolerance + fine_normalized = (past_values_fine - loc_fine.unsqueeze(-1)) / (scale_fine.unsqueeze(-1) + tol) + fine_normalized = fine_normalized * (1.0 - past_values_fine_padding) + fine_normalized = fine_normalized.clamp(-1000.0, 1000.0) + + new_embeddings, new_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + bsize, num_new, _ = new_embeddings.shape + device = new_embeddings.device + dtype = new_embeddings.dtype + + if self.config.use_resolution_embeddings: + mr_idx = torch.full((bsize, num_new), 2, dtype=torch.long, device=device) + new_embeddings = new_embeddings + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + new_embeddings = new_embeddings + self.freq_emb(freq) + + past_length = past_key_values.get_seq_length() + if cache_position is None: + cache_position = torch.arange(past_length, past_length + num_new, dtype=torch.long, device=device) + position_ids = cache_position.unsqueeze(0).expand(bsize, -1) + position_embeddings = self.rotary_emb(new_embeddings, position_ids) + + attention_mask = self._build_incremental_attention_mask(bsize, num_new, past_length, dtype, device) + + hidden_states = new_embeddings + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=new_patch_padding, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + num_fine_patches=num_new, + past_key_values=past_key_values, + ) + class CtsmModelForPrediction(CtsmPreTrainedModel): """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding. - Note: there is no KV cache. Each autoregressive step recomputes the full forward because (1) the - coarse-resolution block uses bidirectional attention, so appending a new coarse patch invalidates - every existing coarse K/V entry, and (2) stream-level normalization is recomputed every step after - new predictions are appended to the raw context, which shifts every patch embedding. This matches - the original CTSM reference (`CTSMAttentionRoPE` explicitly raises on cache arguments) and the - convention of other time-series forecasters in transformers (TimesFM, PatchTST, Informer, ...). + For horizons that require autoregressive decoding (``horizon_len > config.horizon_length``) the + prediction class reuses a key/value cache across AR steps: the first step runs the full forward + and populates a [`DynamicCache`], subsequent steps feed only the newly-appended fine patches + through the stack and attend to the cached K/V for every earlier position. Two caveats, matching + how a KV cache is made to fit CTSM's architecture: + + * Stream-level normalization statistics (``loc_fine``, ``scale_fine``) are frozen to the values + computed on the first step. This is a small approximation: in the untracked reference, + statistics are recomputed after each prediction is appended; in practice the drift is small + when forecasts stay in-distribution. + * If an AR step would grow the coarse block (a new coarse patch is formed once every + ``patch_length * agg_factor / output_patch_len`` steps, i.e. ~every 15 steps at the defaults), + the cache is discarded and a full forward is run, rebuilding the cache. """ def __init__(self, config: CtsmConfig): @@ -897,6 +1062,7 @@ def forward( future_values: torch.Tensor | None = None, horizon_len: int | None = None, freq: Sequence[int] | torch.Tensor | None = None, + use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> CtsmOutputForPrediction: r""" @@ -910,6 +1076,11 @@ def forward( `config.horizon_length` trigger autoregressive decoding. freq (`Sequence[int]` or `torch.Tensor`, *optional*): Frequency indices. Defaults to zeros. + use_cache (`bool`, *optional*): + Whether to use a key/value cache across autoregressive decode steps. Defaults to `True` when + `horizon_len > config.horizon_length` (i.e. when AR decoding is needed) and `False` otherwise. + Set to `False` to force a full recompute at every AR step (matches the original reference + behaviour; slower but avoids the stream-stats-freezing approximation). """ device = self.horizon_ff_layer.input_layer.weight.device horizon_len = horizon_len or self.config.horizon_length @@ -932,21 +1103,49 @@ def forward( mean_chunks: list[torch.Tensor] = [] quant_chunks: list[torch.Tensor] = [] remaining = horizon_len - coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) last_outputs: CtsmOutput | None = None - max_coarse = self.config.context_length max_fine = self.config.context_length + max_coarse = self.config.context_length agg = self.config.agg_factor + new_fine_patches = self.config.horizon_length // self.config.patch_length + + past_key_values: Cache | None = None + frozen_loc_fine: torch.Tensor | None = None + frozen_scale_fine: torch.Tensor | None = None + coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) + + if use_cache is None: + use_cache = num_decode_patches > 1 + pending_new_fine: torch.Tensor | None = None + + for step_idx in range(num_decode_patches): + if past_key_values is None: + # First step (or after cache invalidation): full forward. The coarse block in the cache + # stays frozen at the initial state — only the fine block grows via subsequent incremental + # steps — which matches how KV caches work for append-only sequences. + mean_patch, quant_patch, last_outputs = self._decode_step_full( + past_values_coarse=coarse, + past_values_fine=fine, + past_values_coarse_padding=coarse_pad, + past_values_fine_padding=fine_pad, + freq=freq_tensor, + use_cache=use_cache, + **kwargs, + ) + past_key_values = last_outputs.past_key_values + frozen_loc_fine = last_outputs.loc + frozen_scale_fine = last_outputs.scale + else: + # Incremental: only the fine values newly appended last step go through the stack. + mean_patch, quant_patch, last_outputs = self._decode_step_incremental( + new_fine_values=pending_new_fine, + freq=freq_tensor, + past_key_values=past_key_values, + loc_fine=frozen_loc_fine, + scale_fine=frozen_scale_fine, + **kwargs, + ) - for _ in range(num_decode_patches): - mean_patch, quant_patch, last_outputs = self._decode_step( - past_values_coarse=coarse, - past_values_fine=fine, - past_values_coarse_padding=coarse_pad, - past_values_fine_padding=fine_pad, - freq=freq_tensor, - **kwargs, - ) take = min(remaining, output_patch_len) mean_chunks.append(mean_patch[:, :take]) quant_chunks.append(quant_patch[:, :take, :]) @@ -954,8 +1153,12 @@ def forward( if remaining <= 0: break - # Append fine predictions to fine context. - fine = torch.cat([fine, mean_patch[:, :output_patch_len]], dim=1) + new_fine = mean_patch[:, :output_patch_len] + pending_new_fine = new_fine + + # Track the raw contexts so the next full-forward (initial step or after cache + # invalidation) sees the right state. Mirrors the reference AR loop. + fine = torch.cat([fine, new_fine], dim=1) fine_pad = torch.cat( [fine_pad, torch.zeros((bsize, output_patch_len), device=device, dtype=fine_pad.dtype)], dim=1 ) @@ -963,8 +1166,7 @@ def forward( fine = fine[:, -max_fine:] fine_pad = fine_pad[:, -max_fine:] - # Aggregate into coarse context when enough fine samples accumulated. - coarse_buffer = torch.cat([coarse_buffer, mean_patch[:, :output_patch_len]], dim=1) + coarse_buffer = torch.cat([coarse_buffer, new_fine], dim=1) full_blocks = coarse_buffer.shape[1] // agg if full_blocks > 0: blocks = coarse_buffer[:, : full_blocks * agg].view(bsize, full_blocks, agg).mean(dim=2) @@ -977,6 +1179,12 @@ def forward( coarse = coarse[:, -max_coarse:] coarse_pad = coarse_pad[:, -max_coarse:] + if past_key_values is not None: + projected_len = past_key_values.get_seq_length() + new_fine_patches + if projected_len >= self.config.max_position_embeddings: + past_key_values = None + pending_new_fine = None + mean_predictions = torch.cat(mean_chunks, dim=1)[:, :horizon_len] full_predictions = torch.cat( [torch.cat(mean_chunks, dim=1)[:, :horizon_len, None], torch.cat(quant_chunks, dim=1)[:, :horizon_len, :]], @@ -1077,42 +1285,64 @@ def _prepare_context( return coarse_batch, coarse_pad, fine_batch, fine_pad - def _decode_step( + def _project_last_fine(self, outputs: CtsmOutput, last_position: int) -> tuple[torch.Tensor, torch.Tensor]: + """Project the hidden state at `last_position` through the horizon head and denormalize.""" + last_hidden = outputs.last_hidden_state[:, last_position : last_position + 1, :] + head = self.horizon_ff_layer(last_hidden) + bsize = head.shape[0] + num_outputs = 1 + len(self.config.quantiles) + head = head.view(bsize, self.config.horizon_length, num_outputs) + + loc = outputs.loc[:, None, None] + scale = outputs.scale[:, None, None] + mean_patch = head[..., 0] * scale[..., 0] + loc[..., 0] + quant_patch = head[..., 1:] * scale + loc + mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) + quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + return mean_patch, quant_patch + + def _decode_step_full( self, past_values_coarse: torch.Tensor, past_values_fine: torch.Tensor, past_values_coarse_padding: torch.Tensor, past_values_fine_padding: torch.Tensor, freq: torch.Tensor, + use_cache: bool, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: - """One AR step: return (mean_patch, quantile_patch, model_outputs) at fine resolution. - - mean_patch: `[B, horizon_length]`, quantile_patch: `[B, horizon_length, num_quantiles]`, both denormalized. - """ + """Full forward through the model. If `use_cache`, the returned outputs carry a fresh cache.""" outputs: CtsmOutput = self.model( past_values_coarse=past_values_coarse, past_values_fine=past_values_fine, past_values_coarse_padding=past_values_coarse_padding, past_values_fine_padding=past_values_fine_padding, freq=freq, + use_cache=use_cache, **kwargs, ) - head = self.horizon_ff_layer(outputs.last_hidden_state) - bsize, total_patches, _ = head.shape - num_outputs = 1 + len(self.config.quantiles) - head = head.view(bsize, total_patches, self.config.horizon_length, num_outputs) - - # Last fine patch index in the concatenated sequence. - fine_last_idx = total_patches - 1 - fine_patch = head[:, fine_last_idx, :, :] + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) + return mean_patch, quant_patch, outputs - loc = outputs.loc[:, None, None] - scale = outputs.scale[:, None, None] - mean_patch = fine_patch[..., 0] * scale[..., 0] + loc[..., 0] - quant_patch = fine_patch[..., 1:] * scale + loc - mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) - quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + def _decode_step_incremental( + self, + new_fine_values: torch.Tensor, + freq: torch.Tensor, + past_key_values: Cache, + loc_fine: torch.Tensor, + scale_fine: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """Append `new_fine_values` to the cached state and run only the new positions through the stack.""" + outputs: CtsmOutput = self.model( + past_values_fine=new_fine_values, + freq=freq, + past_key_values=past_key_values, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) return mean_patch, quant_patch, outputs diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py index bc6f7879ee8d..9ac80b54f7f6 100644 --- a/src/transformers/models/ctsm/modular_ctsm.py +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -22,6 +22,7 @@ from huggingface_hub.dataclasses import strict from ... import initialization as init +from ...cache_utils import Cache, DynamicCache from ...modeling_rope_utils import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack @@ -144,12 +145,17 @@ class CtsmOutput(TimesFmOutput): Number of patches (including the optional special token) preceding the fine-resolution block. num_fine_patches (`int`): Number of patches in the fine-resolution block of the concatenated sequence. + past_key_values (`Cache`, *optional*): + Key/value cache for the concatenated `[coarse, special, fine]` sequence. Populated when the + caller passes `use_cache=True` (and re-used across autoregressive decode steps). Typically only + the long-horizon AR loop in [`CtsmModelForPrediction`] needs this. """ loc_coarse: torch.Tensor | None = None scale_coarse: torch.Tensor | None = None num_coarse_patches: int | None = None num_fine_patches: int | None = None + past_key_values: Cache | None = None @dataclass @@ -176,13 +182,19 @@ class CtsmRotaryEmbedding(TimesFm2_5RotaryEmbedding): class CtsmAttention(TimesFmAttention): - """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings.""" + """TimesFM 2.0 style attention with learnable per-dimension Q scaling and rotary position embeddings. + + Supports an optional `past_key_values` cache so that, during long-horizon autoregressive decoding, + each step only needs to compute K/V for the newly-appended fine patches and attends to the + previously-cached K/V for every earlier position. + """ def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -197,6 +209,9 @@ def forward( query_states = self._scale_query(query_states) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, simple_eager_attention_forward ) @@ -229,6 +244,7 @@ def forward( attention_mask: torch.Tensor, paddings: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -237,6 +253,7 @@ def forward( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, + past_key_values=past_key_values, ) hidden_states = residual + hidden_states hidden_states = self.mlp(hidden_states, paddings=paddings) @@ -355,32 +372,96 @@ def _build_attention_mask( attention_mask[..., :num_coarse_patches, :num_coarse_patches] = 0.0 return attention_mask + def _build_incremental_attention_mask( + self, bsize: int, num_new: int, past_length: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Mask for the incremental (cached) path: new fine Qs attend to all cached K/V plus causal within the new block.""" + min_value = torch.finfo(dtype).min + mask = torch.zeros((bsize, 1, num_new, past_length + num_new), dtype=dtype, device=device) + if num_new > 1: + causal_new = torch.triu(torch.full((num_new, num_new), min_value, dtype=dtype, device=device), diagonal=1) + mask[:, :, :, past_length:] = causal_new + return mask + @merge_with_config_defaults @capture_outputs @auto_docstring def forward( self, - past_values_coarse: torch.Tensor, - past_values_fine: torch.Tensor, + past_values_coarse: torch.Tensor | None = None, + past_values_fine: torch.Tensor | None = None, past_values_coarse_padding: torch.LongTensor | None = None, past_values_fine_padding: torch.LongTensor | None = None, freq: torch.Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + loc_fine: torch.Tensor | None = None, + scale_fine: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> CtsmOutput: r""" - past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`): + past_values_coarse (`torch.FloatTensor` of shape `(batch_size, coarse_length)`, *optional*): Coarse-resolution context (e.g. hourly aggregates). Length must be a multiple of `patch_length` or - will be left-padded to one. + will be left-padded to one. Required when `past_key_values` is `None`. past_values_fine (`torch.FloatTensor` of shape `(batch_size, fine_length)`): - Fine-resolution context (e.g. minute-level). Length must be a multiple of `patch_length` or will be - left-padded to one. + Fine-resolution context (e.g. minute-level). In the normal / full-forward mode this is the entire + fine context; when `past_key_values` is supplied this should contain **only the new fine values** + to append — they must already be pre-normalized by the caller using `loc_fine` / `scale_fine`. past_values_coarse_padding (`torch.LongTensor`, *optional*): Padding mask for the coarse stream, `1.0` for padded positions and `0.0` for real values. past_values_fine_padding (`torch.LongTensor`, *optional*): Padding mask for the fine stream. freq (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Frequency indices. Defaults to all zeros. + past_key_values (`Cache`, *optional*): + A [`Cache`] (typically a [`DynamicCache`]) holding K/V for the concatenated + `[coarse, special, fine_prefix]` sequence from a previous call. When supplied the model runs in + **incremental mode**: only the new fine patches are embedded, and their Q/K/V are added on top + of the cached K/V. `loc_fine` / `scale_fine` **must** also be supplied so the new fine values + are normalized on the same scale as the cached ones. + use_cache (`bool`, *optional*): + Whether to build and return a key/value cache in the `CtsmOutput`. Defaults to `False` unless + `past_key_values` is provided (in which case caching is always on). + cache_position (`torch.LongTensor` of shape `(num_new,)`, *optional*): + Absolute positions (in the full `[coarse, special, fine]` sequence) of the new fine patches. + Only used in incremental mode; defaults to `torch.arange(past_length, past_length + num_new)`. + loc_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream mean used for stream normalization. Required in incremental mode. + scale_fine (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Fine-stream standard deviation used for stream normalization. Required in incremental mode. """ + if past_key_values is None: + return self._full_forward( + past_values_coarse=past_values_coarse, + past_values_fine=past_values_fine, + past_values_coarse_padding=past_values_coarse_padding, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + use_cache=bool(use_cache), + **kwargs, + ) + return self._incremental_forward( + past_values_fine=past_values_fine, + past_values_fine_padding=past_values_fine_padding, + freq=freq, + past_key_values=past_key_values, + cache_position=cache_position, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + + def _full_forward( + self, + past_values_coarse: torch.Tensor, + past_values_fine: torch.Tensor, + past_values_coarse_padding: torch.LongTensor | None, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + use_cache: bool, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: if past_values_coarse_padding is None: past_values_coarse_padding = torch.zeros_like(past_values_coarse) if past_values_fine_padding is None: @@ -441,6 +522,8 @@ def forward( ) position_embeddings = self.rotary_emb(model_input, position_ids) + past_key_values = DynamicCache() if use_cache else None + hidden_states = model_input for layer in self.layers[: self.config.num_hidden_layers]: hidden_states = layer( @@ -448,6 +531,7 @@ def forward( attention_mask=attention_mask, paddings=patch_padding, position_embeddings=position_embeddings, + past_key_values=past_key_values, **kwargs, ) @@ -457,20 +541,101 @@ def forward( scale=scale_fine, loc_coarse=loc_coarse, scale_coarse=scale_coarse, - num_coarse_patches=num_coarse_patches + num_special, # fine block starts here + num_coarse_patches=num_coarse_patches + num_special, num_fine_patches=num_fine_patches, + past_key_values=past_key_values, + ) + + def _incremental_forward( + self, + past_values_fine: torch.Tensor, + past_values_fine_padding: torch.LongTensor | None, + freq: torch.Tensor | None, + past_key_values: Cache, + cache_position: torch.LongTensor | None, + loc_fine: torch.Tensor | None, + scale_fine: torch.Tensor | None, + **kwargs: Unpack[TransformersKwargs], + ) -> CtsmOutput: + if loc_fine is None or scale_fine is None: + raise ValueError( + "`loc_fine` and `scale_fine` must be supplied together with `past_key_values` so that the new fine " + "values are normalized on the same scale as the cached ones." + ) + if past_values_fine.shape[1] % self.config.patch_length != 0: + raise ValueError( + f"In incremental mode `past_values_fine` length must be a multiple of `patch_length=" + f"{self.config.patch_length}`; got {past_values_fine.shape[1]}." + ) + + if past_values_fine_padding is None: + past_values_fine_padding = torch.zeros_like(past_values_fine) + past_values_fine_padding = past_values_fine_padding.to(past_values_fine.dtype) + + tol = self.config.tolerance + fine_normalized = (past_values_fine - loc_fine.unsqueeze(-1)) / (scale_fine.unsqueeze(-1) + tol) + fine_normalized = fine_normalized * (1.0 - past_values_fine_padding) + fine_normalized = fine_normalized.clamp(-1000.0, 1000.0) + + new_embeddings, new_patch_padding = self._patchify(fine_normalized, past_values_fine_padding) + bsize, num_new, _ = new_embeddings.shape + device = new_embeddings.device + dtype = new_embeddings.dtype + + if self.config.use_resolution_embeddings: + mr_idx = torch.full((bsize, num_new), 2, dtype=torch.long, device=device) + new_embeddings = new_embeddings + self.multi_resolution(mr_idx) + + if freq is None: + freq = torch.zeros((bsize, 1), dtype=torch.long, device=device) + else: + freq = freq.to(device=device, dtype=torch.long) + new_embeddings = new_embeddings + self.freq_emb(freq) + + past_length = past_key_values.get_seq_length() + if cache_position is None: + cache_position = torch.arange(past_length, past_length + num_new, dtype=torch.long, device=device) + position_ids = cache_position.unsqueeze(0).expand(bsize, -1) + position_embeddings = self.rotary_emb(new_embeddings, position_ids) + + attention_mask = self._build_incremental_attention_mask(bsize, num_new, past_length, dtype, device) + + hidden_states = new_embeddings + for layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer( + hidden_states, + attention_mask=attention_mask, + paddings=new_patch_padding, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + **kwargs, + ) + + return CtsmOutput( + last_hidden_state=hidden_states, + loc=loc_fine, + scale=scale_fine, + num_fine_patches=num_new, + past_key_values=past_key_values, ) class CtsmModelForPrediction(TimesFmModelForPrediction): """CTSM model with a multi-resolution prediction head and autoregressive multi-resolution decoding. - Note: there is no KV cache. Each autoregressive step recomputes the full forward because (1) the - coarse-resolution block uses bidirectional attention, so appending a new coarse patch invalidates - every existing coarse K/V entry, and (2) stream-level normalization is recomputed every step after - new predictions are appended to the raw context, which shifts every patch embedding. This matches - the original CTSM reference (`CTSMAttentionRoPE` explicitly raises on cache arguments) and the - convention of other time-series forecasters in transformers (TimesFM, PatchTST, Informer, ...). + For horizons that require autoregressive decoding (``horizon_len > config.horizon_length``) the + prediction class reuses a key/value cache across AR steps: the first step runs the full forward + and populates a [`DynamicCache`], subsequent steps feed only the newly-appended fine patches + through the stack and attend to the cached K/V for every earlier position. Two caveats, matching + how a KV cache is made to fit CTSM's architecture: + + * Stream-level normalization statistics (``loc_fine``, ``scale_fine``) are frozen to the values + computed on the first step. This is a small approximation: in the untracked reference, + statistics are recomputed after each prediction is appended; in practice the drift is small + when forecasts stay in-distribution. + * If an AR step would grow the coarse block (a new coarse patch is formed once every + ``patch_length * agg_factor / output_patch_len`` steps, i.e. ~every 15 steps at the defaults), + the cache is discarded and a full forward is run, rebuilding the cache. """ def __init__(self, config: CtsmConfig): @@ -554,42 +719,64 @@ def _prepare_context( return coarse_batch, coarse_pad, fine_batch, fine_pad - def _decode_step( + def _project_last_fine(self, outputs: CtsmOutput, last_position: int) -> tuple[torch.Tensor, torch.Tensor]: + """Project the hidden state at `last_position` through the horizon head and denormalize.""" + last_hidden = outputs.last_hidden_state[:, last_position : last_position + 1, :] + head = self.horizon_ff_layer(last_hidden) + bsize = head.shape[0] + num_outputs = 1 + len(self.config.quantiles) + head = head.view(bsize, self.config.horizon_length, num_outputs) + + loc = outputs.loc[:, None, None] + scale = outputs.scale[:, None, None] + mean_patch = head[..., 0] * scale[..., 0] + loc[..., 0] + quant_patch = head[..., 1:] * scale + loc + mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) + quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + return mean_patch, quant_patch + + def _decode_step_full( self, past_values_coarse: torch.Tensor, past_values_fine: torch.Tensor, past_values_coarse_padding: torch.Tensor, past_values_fine_padding: torch.Tensor, freq: torch.Tensor, + use_cache: bool, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: - """One AR step: return (mean_patch, quantile_patch, model_outputs) at fine resolution. - - mean_patch: `[B, horizon_length]`, quantile_patch: `[B, horizon_length, num_quantiles]`, both denormalized. - """ + """Full forward through the model. If `use_cache`, the returned outputs carry a fresh cache.""" outputs: CtsmOutput = self.model( past_values_coarse=past_values_coarse, past_values_fine=past_values_fine, past_values_coarse_padding=past_values_coarse_padding, past_values_fine_padding=past_values_fine_padding, freq=freq, + use_cache=use_cache, **kwargs, ) - head = self.horizon_ff_layer(outputs.last_hidden_state) - bsize, total_patches, _ = head.shape - num_outputs = 1 + len(self.config.quantiles) - head = head.view(bsize, total_patches, self.config.horizon_length, num_outputs) - - # Last fine patch index in the concatenated sequence. - fine_last_idx = total_patches - 1 - fine_patch = head[:, fine_last_idx, :, :] + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) + return mean_patch, quant_patch, outputs - loc = outputs.loc[:, None, None] - scale = outputs.scale[:, None, None] - mean_patch = fine_patch[..., 0] * scale[..., 0] + loc[..., 0] - quant_patch = fine_patch[..., 1:] * scale + loc - mean_patch = torch.nan_to_num(mean_patch, nan=0.0, posinf=0.0, neginf=0.0) - quant_patch = torch.nan_to_num(quant_patch, nan=0.0, posinf=0.0, neginf=0.0) + def _decode_step_incremental( + self, + new_fine_values: torch.Tensor, + freq: torch.Tensor, + past_key_values: Cache, + loc_fine: torch.Tensor, + scale_fine: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, CtsmOutput]: + """Append `new_fine_values` to the cached state and run only the new positions through the stack.""" + outputs: CtsmOutput = self.model( + past_values_fine=new_fine_values, + freq=freq, + past_key_values=past_key_values, + loc_fine=loc_fine, + scale_fine=scale_fine, + **kwargs, + ) + mean_patch, quant_patch = self._project_last_fine(outputs, outputs.last_hidden_state.shape[1] - 1) return mean_patch, quant_patch, outputs @can_return_tuple @@ -600,6 +787,7 @@ def forward( future_values: torch.Tensor | None = None, horizon_len: int | None = None, freq: Sequence[int] | torch.Tensor | None = None, + use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> CtsmOutputForPrediction: r""" @@ -613,6 +801,11 @@ def forward( `config.horizon_length` trigger autoregressive decoding. freq (`Sequence[int]` or `torch.Tensor`, *optional*): Frequency indices. Defaults to zeros. + use_cache (`bool`, *optional*): + Whether to use a key/value cache across autoregressive decode steps. Defaults to `True` when + `horizon_len > config.horizon_length` (i.e. when AR decoding is needed) and `False` otherwise. + Set to `False` to force a full recompute at every AR step (matches the original reference + behaviour; slower but avoids the stream-stats-freezing approximation). """ device = self.horizon_ff_layer.input_layer.weight.device horizon_len = horizon_len or self.config.horizon_length @@ -635,21 +828,49 @@ def forward( mean_chunks: list[torch.Tensor] = [] quant_chunks: list[torch.Tensor] = [] remaining = horizon_len - coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) last_outputs: CtsmOutput | None = None - max_coarse = self.config.context_length max_fine = self.config.context_length + max_coarse = self.config.context_length agg = self.config.agg_factor + new_fine_patches = self.config.horizon_length // self.config.patch_length + + past_key_values: Cache | None = None + frozen_loc_fine: torch.Tensor | None = None + frozen_scale_fine: torch.Tensor | None = None + coarse_buffer = torch.zeros((bsize, 0), dtype=torch.float32, device=device) + + if use_cache is None: + use_cache = num_decode_patches > 1 + pending_new_fine: torch.Tensor | None = None + + for step_idx in range(num_decode_patches): + if past_key_values is None: + # First step (or after cache invalidation): full forward. The coarse block in the cache + # stays frozen at the initial state — only the fine block grows via subsequent incremental + # steps — which matches how KV caches work for append-only sequences. + mean_patch, quant_patch, last_outputs = self._decode_step_full( + past_values_coarse=coarse, + past_values_fine=fine, + past_values_coarse_padding=coarse_pad, + past_values_fine_padding=fine_pad, + freq=freq_tensor, + use_cache=use_cache, + **kwargs, + ) + past_key_values = last_outputs.past_key_values + frozen_loc_fine = last_outputs.loc + frozen_scale_fine = last_outputs.scale + else: + # Incremental: only the fine values newly appended last step go through the stack. + mean_patch, quant_patch, last_outputs = self._decode_step_incremental( + new_fine_values=pending_new_fine, + freq=freq_tensor, + past_key_values=past_key_values, + loc_fine=frozen_loc_fine, + scale_fine=frozen_scale_fine, + **kwargs, + ) - for _ in range(num_decode_patches): - mean_patch, quant_patch, last_outputs = self._decode_step( - past_values_coarse=coarse, - past_values_fine=fine, - past_values_coarse_padding=coarse_pad, - past_values_fine_padding=fine_pad, - freq=freq_tensor, - **kwargs, - ) take = min(remaining, output_patch_len) mean_chunks.append(mean_patch[:, :take]) quant_chunks.append(quant_patch[:, :take, :]) @@ -657,8 +878,12 @@ def forward( if remaining <= 0: break - # Append fine predictions to fine context. - fine = torch.cat([fine, mean_patch[:, :output_patch_len]], dim=1) + new_fine = mean_patch[:, :output_patch_len] + pending_new_fine = new_fine + + # Track the raw contexts so the next full-forward (initial step or after cache + # invalidation) sees the right state. Mirrors the reference AR loop. + fine = torch.cat([fine, new_fine], dim=1) fine_pad = torch.cat( [fine_pad, torch.zeros((bsize, output_patch_len), device=device, dtype=fine_pad.dtype)], dim=1 ) @@ -666,8 +891,7 @@ def forward( fine = fine[:, -max_fine:] fine_pad = fine_pad[:, -max_fine:] - # Aggregate into coarse context when enough fine samples accumulated. - coarse_buffer = torch.cat([coarse_buffer, mean_patch[:, :output_patch_len]], dim=1) + coarse_buffer = torch.cat([coarse_buffer, new_fine], dim=1) full_blocks = coarse_buffer.shape[1] // agg if full_blocks > 0: blocks = coarse_buffer[:, : full_blocks * agg].view(bsize, full_blocks, agg).mean(dim=2) @@ -680,6 +904,12 @@ def forward( coarse = coarse[:, -max_coarse:] coarse_pad = coarse_pad[:, -max_coarse:] + if past_key_values is not None: + projected_len = past_key_values.get_seq_length() + new_fine_patches + if projected_len >= self.config.max_position_embeddings: + past_key_values = None + pending_new_fine = None + mean_predictions = torch.cat(mean_chunks, dim=1)[:, :horizon_len] full_predictions = torch.cat( [torch.cat(mean_chunks, dim=1)[:, :horizon_len, None], torch.cat(quant_chunks, dim=1)[:, :horizon_len, :]], diff --git a/tests/models/ctsm/test_modeling_ctsm.py b/tests/models/ctsm/test_modeling_ctsm.py index abda3ec19263..fa37870a9a3a 100644 --- a/tests/models/ctsm/test_modeling_ctsm.py +++ b/tests/models/ctsm/test_modeling_ctsm.py @@ -251,6 +251,32 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict["future_values"] = floats_tensor([batch_size, self.model_tester.horizon_length], rng=rng) return inputs_dict + def test_kv_cache_matches_full_recompute(self): + """Cached autoregressive decoding should produce close-to-identical predictions to the + full-recompute path (the small gap is from the stream-stats-freezing approximation).""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = CtsmModelForPrediction(config).to(torch_device).eval() + + # Long enough to trigger AR (horizon > config.horizon_length). + horizon_len = config.horizon_length * 3 + with torch.no_grad(): + out_full = model(**inputs_dict, horizon_len=horizon_len, use_cache=False) + out_cache = model(**inputs_dict, horizon_len=horizon_len, use_cache=True) + + # First horizon_length predictions must match bit-exactly (step 1 is identical in both paths). + step1 = config.horizon_length + self.assertTrue( + torch.allclose(out_full.mean_predictions[:, :step1], out_cache.mean_predictions[:, :step1], atol=1e-5), + msg="Step-1 predictions must match bit-exactly between cached and non-cached paths.", + ) + # On subsequent AR steps the stats-freezing approximation introduces a small bounded drift. + # The bound is generous here because the tiny tester model has random weights and a horizon of 8, + # so compounding any small per-step shift over multiple steps is amplified. + relative = (out_full.mean_predictions - out_cache.mean_predictions).abs().max() / ( + out_full.mean_predictions.abs().max().clamp_min(1.0) + ) + self.assertLess(relative.item(), 0.5, f"cached vs full-recompute AR drift {relative.item():.2e} too large") + @require_torch @slow From 34a4cdf474666fbc56d2086cbc20ced51c2ed7a5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 18 Apr 2026 13:58:01 +0200 Subject: [PATCH 9/9] ctsm: pass config to DynamicCache (Llama convention) --- src/transformers/models/ctsm/modeling_ctsm.py | 2 +- src/transformers/models/ctsm/modular_ctsm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ctsm/modeling_ctsm.py b/src/transformers/models/ctsm/modeling_ctsm.py index 4d63554ff3a2..e0ae7cbff49e 100644 --- a/src/transformers/models/ctsm/modeling_ctsm.py +++ b/src/transformers/models/ctsm/modeling_ctsm.py @@ -860,7 +860,7 @@ def _full_forward( ) position_embeddings = self.rotary_emb(model_input, position_ids) - past_key_values = DynamicCache() if use_cache else None + past_key_values = DynamicCache(config=self.config) if use_cache else None hidden_states = model_input for layer in self.layers[: self.config.num_hidden_layers]: diff --git a/src/transformers/models/ctsm/modular_ctsm.py b/src/transformers/models/ctsm/modular_ctsm.py index 9ac80b54f7f6..e56fe16403c5 100644 --- a/src/transformers/models/ctsm/modular_ctsm.py +++ b/src/transformers/models/ctsm/modular_ctsm.py @@ -522,7 +522,7 @@ def _full_forward( ) position_embeddings = self.rotary_emb(model_input, position_ids) - past_key_values = DynamicCache() if use_cache else None + past_key_values = DynamicCache(config=self.config) if use_cache else None hidden_states = model_input for layer in self.layers[: self.config.num_hidden_layers]: