From fcfd9e709d4472e2ba8d84c7f51f5fa68fc44db0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Aug 2024 22:42:30 +0200 Subject: [PATCH 001/242] initial copy from t5 --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/timesfm.md | 71 + src/transformers/__init__.py | 24 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 8 + .../models/auto/tokenization_auto.py | 6 + src/transformers/models/timesfm/__init__.py | 66 + .../models/timesfm/configuration_timesfm.py | 163 ++ ...mesfm_original_tf_checkpoint_to_pytorch.py | 59 + .../convert_timesfmx_checkpoint_to_flax.py | 235 ++ .../convert_timesfmx_checkpoint_to_pytorch.py | 238 ++ .../models/timesfm/modeling_timesfm.py | 2388 +++++++++++++++++ tests/models/timesfm/__init__.py | 0 tests/models/timesfm/test_modeling_timesfm.py | 1459 ++++++++++ 15 files changed, 4722 insertions(+) create mode 100644 docs/source/en/model_doc/timesfm.md create mode 100644 src/transformers/models/timesfm/__init__.py create mode 100644 src/transformers/models/timesfm/configuration_timesfm.py create mode 100644 src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py create mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/timesfm/modeling_timesfm.py create mode 100644 tests/models/timesfm/__init__.py create mode 100644 tests/models/timesfm/test_modeling_timesfm.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c5e3bcddfca9..12b812359282 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -556,6 +556,8 @@ title: T5v1.1 - local: model_doc/tapex title: TAPEX + - local: model_doc/timesfm + title: TimesFM - local: model_doc/transfo-xl title: Transformer XL - local: model_doc/ul2 diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md new file mode 100644 index 000000000000..9acc824f9e0f --- /dev/null +++ b/docs/source/en/model_doc/timesfm.md @@ -0,0 +1,71 @@ + + +# TimesFM + +## Overview + +The TimesFM model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## TimesFMConfig + +[[autodoc]] TimesFMConfig + +## TimesFMModel + +[[autodoc]] TimesFMModel + - forward + +## TimesFMForConditionalGeneration + +[[autodoc]] TimesFMForConditionalGeneration + - forward + +## TimesFMEncoderModel + +[[autodoc]] TimesFMEncoderModel + - forward + +## TimesFMForSequenceClassification + +[[autodoc]] TimesFMForSequenceClassification + - forward + +## TimesFMForTokenClassification + +[[autodoc]] TimesFMForTokenClassification + - forward + +## TimesFMForQuestionAnswering + +[[autodoc]] TimesFMForQuestionAnswering + - forward + + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 41d932c0ff96..5179907d6b35 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -741,6 +741,7 @@ "models.swinv2": ["Swinv2Config"], "models.switch_transformers": ["SwitchTransformersConfig"], "models.t5": ["T5Config"], + "models.timesfm": ["TimesFMConfig"], "models.table_transformer": ["TableTransformerConfig"], "models.tapas": [ "TapasConfig", @@ -3354,6 +3355,18 @@ "load_tf_weights_in_t5", ] ) + _import_structure["models.timesfm"].extend( + [ + "TimesFMEncoderModel", + "TimesFMForConditionalGeneration", + "TimesFMForQuestionAnswering", + "TimesFMForSequenceClassification", + "TimesFMForTokenClassification", + "TimesFMModel", + "TimesFMPreTrainedModel", + "load_tf_weights_in_timesfm", + ] + ) _import_structure["models.table_transformer"].extend( [ "TableTransformerForObjectDetection", @@ -5536,6 +5549,7 @@ SwitchTransformersConfig, ) from .models.t5 import T5Config + from .models.timesfm import TimesFMConfig from .models.table_transformer import ( TableTransformerConfig, ) @@ -7747,6 +7761,16 @@ T5PreTrainedModel, load_tf_weights_in_t5, ) + from .models.timesfm import ( + TimesFMEncoderModel, + TimesFMForConditionalGeneration, + TimesFMForQuestionAnswering, + TimesFMForSequenceClassification, + TimesFMForTokenClassification, + TimesFMModel, + TimesFMPreTrainedModel, + load_tf_weights_in_timesfm, + ) from .models.table_transformer import ( TableTransformerForObjectDetection, TableTransformerModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 1afc3f5ae7ae..8190f96c1e9b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -229,6 +229,7 @@ swinv2, switch_transformers, t5, + timesfm, table_transformer, tapas, time_series_transformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 625efc3a8c9d..d90e416d6957 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -254,6 +254,7 @@ ("swinv2", "Swinv2Config"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), + ("timesfm", "TimesFMConfig"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), @@ -556,6 +557,7 @@ ("swinv2", "Swin Transformer V2"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), + ("timesfm", "TimesFM"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7049d38e9fbf..a3919a2005c8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -235,6 +235,7 @@ ("swinv2", "Swinv2Model"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), + ("timesfm", "TimesFMModel"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), @@ -339,6 +340,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("timesfm", "TimesFMForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -432,6 +434,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("timesfm", "TimesFMForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -845,6 +848,7 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("timesfm", "TimesFMForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] @@ -944,6 +948,7 @@ ("stablelm", "StableLmForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), + ("timesfm", "TimesFMForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), @@ -1016,6 +1021,7 @@ ("splinter", "SplinterForQuestionAnswering"), ("squeezebert", "SqueezeBertForQuestionAnswering"), ("t5", "T5ForQuestionAnswering"), + ("timesfm", "TimesFMForQuestionAnswering"), ("umt5", "UMT5ForQuestionAnswering"), ("xlm", "XLMForQuestionAnsweringSimple"), ("xlm-roberta", "XLMRobertaForQuestionAnswering"), @@ -1116,6 +1122,7 @@ ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), + ("timesfm", "TimesFMForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -1337,6 +1344,7 @@ ("roformer", "RoFormerModel"), ("squeezebert", "SqueezeBertModel"), ("t5", "T5EncoderModel"), + ("timesfm", "TimesFMEncoderModel"), ("umt5", "UMT5EncoderModel"), ("xlm", "XLMModel"), ("xlm-roberta", "XLMRobertaModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index b094f50b5e97..25055a5baa20 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -472,6 +472,12 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "timesfm", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py new file mode 100644 index 000000000000..7398dccbda88 --- /dev/null +++ b/src/transformers/models/timesfm/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = {"configuration_timesfm": ["TimesFMConfig", "TimesFMOnnxConfig"]} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_timesfm"] = [ + "TimesFMEncoderModel", + "TimesFMForConditionalGeneration", + "TimesFMModel", + "TimesFMPreTrainedModel", + "load_tf_weights_in_timesfm", + "TimesFMForQuestionAnswering", + "TimesFMForSequenceClassification", + "TimesFMForTokenClassification", + ] + +if TYPE_CHECKING: + from .configuration_timesfm import TimesFMConfig, TimesFMOnnxConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_timesfm import ( + TimesFMEncoderModel, + TimesFMForConditionalGeneration, + TimesFMForQuestionAnswering, + TimesFMForSequenceClassification, + TimesFMForTokenClassification, + TimesFMModel, + TimesFMPreTrainedModel, + load_tf_weights_in_timesfm, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py new file mode 100644 index 000000000000..065f779b557c --- /dev/null +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2020, The TimesFM Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TimesFM model configuration""" + +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimesFMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TimesFMModel`] or a [`TFTimesFMModel`]. It is used to + instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the TimesFM + [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the TimesFM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`TimesFMModel`] or [`TFTimesFMModel`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `TimesFMBlock`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. TimesFMv1.1 uses the + `"gated-gelu"` feed forward projection. Original TimesFM uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "timesfm" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class TimesFMOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..aa66a8392d4f --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2024 The TimesFM authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert TimesFM checkpoint.""" + +import argparse + +from transformers import TimesFMConfig, TimesFMForConditionalGeneration, load_tf_weights_in_timesfm +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = TimesFMConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = TimesFMForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_timesfm(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained TimesFM model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py new file mode 100644 index 000000000000..98570e22876e --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert TimesFMX checkpoints from the original repository to JAX/FLAX model.""" + +import argparse + +from timesfmx import checkpoints + +from transformers import FlaxTimesFMForConditionalGeneration, TimesFMConfig + + +def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, flax_dump_folder_path): + config = TimesFMConfig.from_pretrained(config_name) + flax_model = FlaxTimesFMForConditionalGeneration(config=config) + timesfmx_model = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) + + split_mlp_wi = "wi_0" in timesfmx_model["target"]["encoder"]["layers_0"]["mlp"] + + # Encoder + for layer_index in range(config.num_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + + # Layer Normalization + timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + + if split_mlp_wi: + timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + + timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( + timesfmx_attention_key + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( + timesfmx_attention_out + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( + timesfmx_attention_query + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( + timesfmx_attention_value + ) + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( + timesfmx_attention_layer_norm + ) + + if split_mlp_wi: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = timesfmx_mlp_wi_0 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = timesfmx_mlp_wi_1 + else: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = ( + timesfmx_mlp_wi + ) + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = ( + timesfmx_mlp_wo + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( + timesfmx_mlp_layer_norm + ) + + # Only for layer 0: + timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = timesfmx_encoder_rel_embedding + + # Assigning + timesfmx_encoder_norm = timesfmx_model["target"]["encoder"]["encoder_norm"]["scale"] + flax_model.params["encoder"]["final_layer_norm"]["weight"] = timesfmx_encoder_norm + + # Decoder + for layer_index in range(config.num_decoder_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + + # Layer Normalization + timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ + "scale" + ] + + # Encoder-Decoder-Attention + timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ + "kernel" + ] + timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ + "kernel" + ] + timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ + "kernel" + ] + timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ + "kernel" + ] + + # Layer Normalization + timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + + # MLP + if split_mlp_wi: + timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + + timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( + timesfmx_attention_key + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( + timesfmx_attention_out + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( + timesfmx_attention_query + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( + timesfmx_attention_value + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( + timesfmx_pre_attention_layer_norm + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = ( + timesfmx_enc_dec_attention_key + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = ( + timesfmx_enc_dec_attention_out + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = ( + timesfmx_enc_dec_attention_query + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = ( + timesfmx_enc_dec_attention_value + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( + timesfmx_cross_layer_norm + ) + + if split_mlp_wi: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = timesfmx_mlp_wi_0 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = timesfmx_mlp_wi_1 + else: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = ( + timesfmx_mlp_wi + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = ( + timesfmx_mlp_wo + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = ( + tx5_mlp_layer_norm + ) + + # Decoder Normalization + tx5_decoder_norm = timesfmx_model["target"]["decoder"]["decoder_norm"]["scale"] + flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm + + # Only for layer 0: + timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = timesfmx_decoder_rel_embedding + + # Token Embeddings + tx5_token_embeddings = timesfmx_model["target"]["token_embedder"]["embedding"] + flax_model.params["shared"]["embedding"] = tx5_token_embeddings + + # LM Head (only in v1.1 checkpoints) + if "logits_dense" in timesfmx_model["target"]["decoder"]: + flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"]["logits_dense"]["kernel"] + + flax_model.save_pretrained(flax_dump_folder_path) + print("TimesFMX Model was sucessfully converted!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + ) + parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of TimesFM model.") + parser.add_argument( + "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + ) + args = parser.parse_args() + convert_timesfmx_checkpoint_to_flax(args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..b761d76bbdcd --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert TimesFMX checkpoint to PyTorch + +Steps: +- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install +- Get a TimesFMX checkpoint at https://github.com/google-research/timesfmx/blob/main/docs/models.md#timesfm-11-checkpoints Example: + `gsutil -m cp -r gs://timesfm-data/pretrained_models/timesfmx/timesfm_1_1_small $HOME/` +- Create or download a corresponding config for the downloaded model. E.g. for TimesFM v1.1 small, you can use + https://huggingface.co/google/timesfm-v1_1-small/blob/main/config.json +- Convert: + ``` + python3 convert_timesfmx_checkpoint_to_pytorch.py --timesfmx_checkpoint_path=$HOME/timesfm_1_1_small --config_file=config.json\ + --pytorch_dump_path=$HOME/timesfm_1_1_small_pt + ``` +""" + +import argparse +import collections + +import torch +from flax import traverse_util +from timesfmx import checkpoints + +from transformers import TimesFMConfig, TimesFMEncoderModel, TimesFMForConditionalGeneration +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def timesfmx_attention_lookup(params, i, prefix, layer_name="attention"): + """Returns the KOQV parameters of (self-)attention. Does not transpose.""" + k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] + o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] + q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] + v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] + return k, o, q, v + + +def timesfmx_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] + return wi, wo + + +def timesfmx_layer_norm_lookup(params, i, prefix, layer_name): + """Returns the layer norm param of a layer.""" + return params[f"{prefix}/layers_{i}/{layer_name}/scale"] + + +def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool): + """Converts the parameters from TimesFMX-Flax to Transformers-PyTorch.""" + old = traverse_util.flatten_dict(variables["target"]) + old = {"/".join(k): v for k, v in old.items()} + + # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi + split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old + print("Split MLP:", split_mlp_wi) + + new = collections.OrderedDict() + + # Shared embeddings. + new["shared.weight"] = old["token_embedder/embedding"] + + # Encoder. + for i in range(num_layers): + # Block i, layer 0 (Self Attention). + layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm") + k, o, q, v = timesfmx_attention_lookup(old, i, "encoder", "attention") + new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm + new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T + new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T + new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T + new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T + + # Block i, layer 1 (MLP). + layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm") + wi, wo = timesfmx_mlp_lookup(old, i, "encoder", split_mlp_wi) + new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm + if split_mlp_wi: + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T + else: + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T + new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T + + new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ + "encoder/relpos_bias/rel_embedding" + ].T + new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] + + if not is_encoder_only: + # Decoder. + for i in range(num_decoder_layers): + # Block i, layer 0 (Self Attention). + layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "self_attention") + new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm + new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T + new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T + new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T + new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T + + # Block i, layer 1 (Cross Attention). + layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") + k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "encoder_decoder_attention") + new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm + new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T + + # Block i, layer 2 (MLP). + layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + wi, wo = timesfmx_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm + if split_mlp_wi: + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T + else: + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T + new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T + + new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] + new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ + "decoder/relpos_bias/rel_embedding" + ].T + + # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) + if "decoder/logits_dense/kernel" in old: + new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T + + return new + + +def make_state_dict(converted_params, is_encoder_only: bool): + """Prepares a state dict for the PyTorch model.""" + # Make a state dict with torch tensors. + state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + + # Add what is missing. + if "encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + if not is_encoder_only: + if "decoder.embed_tokens.weight" not in state_dict: + state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] + + if "lm_head.weight" not in state_dict: # For old 1.0 models. + print("Using shared word embeddings as lm_head.") + state_dict["lm_head.weight"] = state_dict["shared.weight"] + + return state_dict + + +def load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only): + """Replaces the params in model witht the TimesFMX converted params.""" + variables = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) + converted = convert_timesfmx_to_pytorch( + variables, + num_layers=config.num_layers, + num_decoder_layers=config.num_decoder_layers, + is_encoder_only=is_encoder_only, + ) + state_dict = make_state_dict(converted, is_encoder_only) + model.load_state_dict(state_dict, strict=True) + + +def convert_timesfmx_checkpoint_to_pytorch( + timesfmx_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False +): + """Loads the config and model, converts the TimesFMX checkpoint, and saves a PyTorch checkpoint.""" + # Initialise PyTorch model + config = TimesFMConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + # Non-v1.1 checkpoints could also use TimesFMModel, but this works for all. + # The v1.0 checkpoints will simply have an LM head that is the word embeddings. + if is_encoder_only: + model = TimesFMEncoderModel(config) + else: + model = TimesFMForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Verify that we can load the checkpoint. + model.from_pretrained(pytorch_dump_path) + print("Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint.") + # Required parameters + parser.add_argument( + "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path to the TimesFMX checkpoint." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained TimesFM model.\nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False + ) + args = parser.parse_args() + convert_timesfmx_checkpoint_to_pytorch( + args.timesfmx_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only + ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 000000000000..ea6daa33ac6b --- /dev/null +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,2388 @@ +# coding=utf-8 +# Copyright 2024 Mesh TensorFlow authors, TimesFM Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TimesFM model.""" + +import copy +import math +import os +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_timesfm import TimesFMConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TimesFMConfig" +_CHECKPOINT_FOR_DOC = "google/timesfm-1.0-200m" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +# Copied from transformers.models.t5.modeling_t5.load_tf_weights_in_t5 with t5->timesfm +def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, *optional*): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the timesfm models have the + following number of attention modules: + + - google/timesfm-1.0-200m: 6 + - google-timesfm/timesfm-base: 12 + - google-timesfm/timesfm-large: 24 + - google-timesfm/timesfm-3b: 24 + - google-timesfm/timesfm-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using google-timesfm/timesfm-3b, which has a total of 24 attention modules: + model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with google-timesfm/timesfm-3b: + model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->TimesFM +class TimesFMLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the TimesFM style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # TimesFM uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + TimesFMLayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm") +except ImportError: + # using the normal TimesFMLayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to TimesFMLayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->TimesFM +class TimesFMDenseActDense(nn.Module): + def __init__(self, config: TimesFMConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->TimesFM,t5->timesfm +class TimesFMDenseGatedActDense(nn.Module): + def __init__(self, config: TimesFMConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-timesfm-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->TimesFM +class TimesFMLayerFF(nn.Module): + def __init__(self, config: TimesFMConfig): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = TimesFMDenseGatedActDense(config) + else: + self.DenseReluDense = TimesFMDenseActDense(config) + + self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->TimesFM +class TimesFMAttention(nn.Module): + def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->TimesFM +class TimesFMLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = TimesFMAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->TimesFM +class TimesFMLayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = TimesFMAttention(config, has_relative_attention_bias=False) + self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->TimesFM +class TimesFMBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(TimesFMLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(TimesFMLayerCrossAttention(config)) + + self.layer.append(TimesFMLayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->TimesFM +class TimesFMClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: TimesFMConfig): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->TimesFM,t5->timesfm +class TimesFMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TimesFMConfig + load_tf_weights = load_tf_weights_in_timesfm + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["TimesFMBlock"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, TimesFMLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (TimesFMModel, TimesFMForConditionalGeneration, TimesFMEncoderModel, TimesFMForQuestionAnswering), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, TimesFMForTokenClassification): + if hasattr(module, "classifier"): + module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.data.zero_() + elif isinstance(module, TimesFMClassificationHead): + module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, TimesFMDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, TimesFMDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, TimesFMAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In TimesFM it is usually set to the pad_token_id. " + "See TimesFM docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->TimesFM +class TimesFMStack(TimesFMPreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMStack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +TIMESFM_START_DOCSTRING = r""" + + The TIMESFM model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TimesFMConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TIMESFM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + TIMESFM uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [TIMESFM + Training](./timesfm#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +TIMESFM_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare TIMESFM Model transformer outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFMModel(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TimesFMStack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" + " 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, TimesFMModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") + >>> model = TimesFMModel.from_pretrained("google/timesfm-1.0-200m") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for TimesFMModel. + >>> # This is not needed for torch's TimesFMForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING) +class TimesFMForConditionalGeneration(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TimesFMStack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" + " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" + " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TimesFMForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") + >>> model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + decoder_attention_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare TIMESFM Model transformer outputting encoder's raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFMEncoderModel(TimesFMPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight"] + _keys_to_ignore_on_load_unexpected = [r"decoder"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMEncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, TimesFMEncoderModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") + >>> model = TimesFMEncoderModel.from_pretrained("google/timesfm-1.0-200m") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs + + +@add_start_docstrings( + """ + TIMESFM model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + TIMESFM_START_DOCSTRING, +) +class TimesFMForSequenceClassification(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.transformer = TimesFMModel(config) + self.classification_head = TimesFMClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + # decoder_input_ids from input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + batch_size, _, hidden_size = sequence_output.shape + sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + TIMESFM Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output) + e.g. for Named-Entity-Recognition (NER) tasks. + """, + TIMESFM_START_DOCSTRING, +) +class TimesFMForTokenClassification(TimesFMPreTrainedModel): + _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = TimesFMEncoderModel(config) + self.dropout = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, outputs[2:-1]) + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + TIMESFM Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers + on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + TIMESFM_START_DOCSTRING, +) +class TimesFMForQuestionAnswering(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TimesFMStack(decoder_config, self.shared) + + self.num_labels = config.num_labels + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if start_positions is not None and end_positions is not None: + use_cache = False + + # different to other models, TIMESFM automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=None, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/tests/models/timesfm/__init__.py b/tests/models/timesfm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py new file mode 100644 index 000000000000..e5878f8c51c7 --- /dev/null +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -0,0 +1,1459 @@ +# coding=utf-8 +# Copyright 2024 Google TimesFM Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import os +import pickle +import tempfile +import unittest + +from transformers import TimesFMConfig, is_torch_available +from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +from transformers.testing_utils import ( + require_accelerate, + require_sentencepiece, + require_tokenizers, + require_torch, + slow, + torch_device, +) +from transformers.utils import cached_property, is_torch_fx_available + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + + +if is_torch_available(): + import torch + + from transformers import ( + AutoTokenizer, + ByT5Tokenizer, + TimesFMEncoderModel, + TimesFMForConditionalGeneration, + TimesFMForQuestionAnswering, + TimesFMForSequenceClassification, + TimesFMForTokenClassification, + TimesFMModel, + T5Tokenizer, + ) + + +class TimesFMModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=7, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + decoder_layers=None, + ): + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = decoder_layers + + def get_large_model_config(self): + return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2) + input_ids[:, -1] = self.eos_token_id # Eos Token + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = self.get_config() + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def get_pipeline_config(self): + return TimesFMConfig( + vocab_size=166, # timesfm forces 100 extra tokens + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def get_config(self): + return TimesFMConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def check_prepare_lm_labels_via_shift_left( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config) + model.to(torch_device) + model.eval() + + # make sure that lm_labels are correctly padded from the right + lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) + + # add casaul pad token mask + triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() + lm_labels.masked_fill_(triangular_mask, self.pad_token_id) + decoder_input_ids = model._shift_right(lm_labels) + + for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): + # first item + self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) + if i < decoder_input_ids_slice.shape[-1]: + if i < decoder_input_ids.shape[-1] - 1: + # items before diagonal + self.parent.assertListEqual( + decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() + ) + # pad items after diagonal + if i < decoder_input_ids.shape[-1] - 2: + self.parent.assertListEqual( + decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() + ) + else: + # all items after square + self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_with_sequence_classification_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device) + model = TimesFMForSequenceClassification(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=input_ids, + labels=labels, + ) + # self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).get_decoder() + model.to(torch_device) + model.eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).to(torch_device).half().eval() + output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_encoder_decoder_shared_weights( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + for model_class in [TimesFMModel, TimesFMForConditionalGeneration]: + torch.manual_seed(0) + model = model_class(config=config).to(torch_device).eval() + # load state dict copies weights but does not tie them + model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) + + torch.manual_seed(0) + tied_config = copy.deepcopy(config) + tied_config.tie_encoder_decoder = True + tied_model = model_class(config=tied_config).to(torch_device).eval() + + model_result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + tied_model.save_pretrained(tmpdirname) + tied_model = model_class.from_pretrained(tmpdirname) + tied_model.to(torch_device) + tied_model.eval() + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], + tied_model_result[0][0, :, random_slice_idx], + atol=1e-4, + ) + ) + + def check_resize_embeddings_timesfm_v1_1( + self, + config, + ): + prev_vocab_size = config.vocab_size + + config.tie_word_embeddings = False + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model.resize_token_embeddings(prev_vocab_size - 10) + + self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "use_cache": False, + } + return config, inputs_dict + + +@require_torch +class TimesFMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForSequenceClassification, TimesFMForQuestionAnswering) + if is_torch_available() + else () + ) + all_generative_model_classes = (TimesFMForConditionalGeneration,) if is_torch_available() else () + all_parallelizable_model_classes = (TimesFMModel, TimesFMForConditionalGeneration) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = True + test_model_parallel = True + is_encoder_decoder = True + # The small TimesFM model needs higher percentages for CPU/MP tests + model_split_percents = [0.5, 0.8, 0.9] + + def setUp(self): + self.model_tester = TimesFMModelTester(self) + self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) + + # TimesFMForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForQuestionAnswering): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_config_and_model_silu_gated(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.feed_forward_proj = "gated-silu" + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_with_sequence_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + def test_encoder_decoder_shared_weights(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + def test_v1_1_resize_embeddings(self): + config = self.model_tester.prepare_config_and_inputs()[0] + self.model_tester.check_resize_embeddings_timesfm_v1_1(config) + + @slow + def test_model_from_pretrained(self): + model_name = "google/timesfm-1.0-200m" + model = TimesFMModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip(reason="Test has a segmentation fault on torch 1.8.0") + def test_export_to_onnx(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + model = TimesFMModel(config_and_inputs[0]).to(torch_device) + with tempfile.TemporaryDirectory() as tmpdirname: + torch.onnx.export( + model, + (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), + f"{tmpdirname}/timesfm_test.onnx", + export_params=True, + opset_version=9, + input_names=["input_ids", "decoder_input_ids"], + ) + + def test_generate_with_head_masking(self): + attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + max_length = config_and_inputs[1].shape[-1] + 3 + model = TimesFMForConditionalGeneration(config).eval() + model.to(torch_device) + + head_masking = { + "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device), + "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + } + + for attn_name, (name, mask) in zip(attention_names, head_masking.items()): + head_masks = {name: mask} + # Explicitly pass decoder_head_mask as it is required from TimesFM model when head_mask specified + if name == "head_mask": + head_masks["decoder_head_mask"] = torch.ones( + config.num_decoder_layers, config.num_heads, device=torch_device + ) + + out = model.generate( + config_and_inputs[1], + num_beams=1, + max_length=max_length, + output_attentions=True, + return_dict_in_generate=True, + **head_masks, + ) + # We check the state of decoder_attentions and cross_attentions just from the last step + attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] + self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + + +class TimesFMEncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + # For common tests + use_attention_mask=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + is_training=False, + dropout_rate=0.1, + initializer_factor=0.002, + is_encoder_decoder=False, + eos_token_id=1, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + # For common tests + self.seq_length = self.encoder_seq_length + self.use_attention_mask = use_attention_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.is_encoder_decoder = is_encoder_decoder + self.scope = None + self.is_training = is_training + + def get_large_model_config(self): + return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = TimesFMConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = TimesFMEncoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + attention_mask, + ): + model = TimesFMEncoderModel(config=config).to(torch_device).half().eval() + output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_with_token_classification_head( + self, + config, + input_ids, + attention_mask, + ): + labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) + model = TimesFMForTokenClassification(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + ) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class TimesFMEncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (TimesFMEncoderModel, TimesFMForTokenClassification) if is_torch_available() else () + test_pruning = False + test_resize_embeddings = False + test_model_parallel = True + pipeline_model_mapping = ( + { + "token-classification": TimesFMForTokenClassification, + } + if is_torch_available() + else {} + ) + all_parallelizable_model_classes = (TimesFMEncoderModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = TimesFMEncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + def test_with_token_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) + + +def use_task_specific_params(model, task): + model.config.update(model.config.task_specific_params[task]) + + +@require_torch +@require_accelerate +@require_tokenizers +@slow +class TimesFMModelFp16Tests(unittest.TestCase): + def test_fp16_fp32_conversion(self): + r""" + A test to check whether the argument `keep_in_fp32_modules` correctly does its job + """ + orig_import = __import__ + accelerate_mock = unittest.mock.Mock() + + # mock import of accelerate + def import_accelerate_mock(name, *args, **kwargs): + if name == "accelerate": + if accelerate_available: + return accelerate_mock + else: + raise ImportError + return orig_import(name, *args, **kwargs) + + # Load without using `accelerate` + with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): + accelerate_available = False + + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load without in bf16 + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, device_map="auto" + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load without using `accelerate` + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load using `accelerate` + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.float16, device_map="auto" + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class TimesFMModelIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-base").to(torch_device) + + @cached_property + def tokenizer(self): + return T5Tokenizer.from_pretrained("google-timesfm/timesfm-base") + + @slow + def test_torch_quant(self): + r""" + Test that a simple `torch.quantization.quantize_dynamic` call works on a TimesFM model. + """ + model_name = "google/flan-timesfm-small" + tokenizer = T5Tokenizer.from_pretrained(model_name) + model = TimesFMForConditionalGeneration.from_pretrained(model_name) + model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) + input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" + input_ids = tokenizer(input_text, return_tensors="pt").input_ids + _ = model.generate(input_ids) + + @slow + def test_small_generation(self): + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + model.config.max_length = 8 + model.config.num_beams = 1 + model.config.do_sample = False + tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") + + input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device) + + sequences = model.generate(input_ids) + + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] + self.assertTrue(output_str == "Hello there!") + + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import timesfm # pip install timesfm==0.7.1 + >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_timesfm_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -19.0845 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_v1_1_integration_test(self): + """ + For comparision run: + >>> import timesfm # pip install timesfm==0.7.1 + >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_timesfm_v1_1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_v1_1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-v1_1-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google/timesfm-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -59.0293 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_bytimesfm_integration_test(self): + """ + For comparision run: + >>> import timesfm # pip install timesfm==0.9.1 + + >>> path_to_bytimesfm_small_checkpoint = '' + >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) + >>> vocab = timesfm.data.ByteVocabulary() + >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TimesFMForConditionalGeneration.from_pretrained("google/bytimesfm-small").to(torch_device) + tokenizer = ByT5Tokenizer.from_pretrained("google/bytimesfm-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -60.7397 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_summarization(self): + model = self.model + tok = self.tokenizer + + FRANCE_ARTICLE = ( # @noqa + "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" + " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." + ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' + ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' + " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" + " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" + " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" + " phone at the wreckage site. The two publications described the supposed video, but did not post it on" + " their websites. The publications said that they watched the video, which was found by a source close to" + " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." + ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' + " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" + ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' + " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" + " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" + " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" + ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' + ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' + " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" + " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" + " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" + ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' + ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' + ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' + ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' + " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" + ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' + " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" + " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" + ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' + ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' + " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" + " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" + " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" + " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" + ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' + " sharing the information and documents -- including training and medical records -- with public" + " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" + " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" + " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" + " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" + " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." + " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" + " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." + " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." + " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" + " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" + " the flight school during his training were among several developments as investigators continued to" + " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" + " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" + ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' + " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" + " some point before his aviation career and underwent psychotherapy before he got his pilot's license." + " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" + " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" + " lose his pilot's license, a European government official briefed on the investigation told CNN on" + ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' + " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" + " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" + " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" + " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" + " he had psychological issues, the European government official said. But no matter what details emerge" + " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" + ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' + " that maybe they weren't going to keep doing their job and they're upset about that and so they're" + ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' + " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" + ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' + " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" + " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" + " Amiel and Anna-Maja Rappard contributed to this report." + ) + SHORTER_ARTICLE = ( + "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" + " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" + " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." + " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" + ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' + ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' + " situation in Palestinian territories, paving the way for possible war crimes investigations against" + " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" + " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" + " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" + ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' + ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' + ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' + " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" + ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' + " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." + ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' + ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' + " immediately end their pressure, and countries that support universal acceptance of the court's treaty" + ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' + " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" + ' decision to join a treaty to which over 100 countries around the world are members." In January, when' + " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" + ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' + " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" + ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' + ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' + ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' + " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" + ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' + " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" + ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' + " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" + " will include alleged war crimes committed since June. The International Criminal Court was set up in" + " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" + " and Faith Karimi contributed to this report." + ) + IRAN_ARTICLE = ( + "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" + " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" + " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." + " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" + " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" + " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" + " the announcement of the new framework will likely result in more heat than light. It will not be helped" + " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." + " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" + " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" + " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" + " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" + " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" + " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" + " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" + " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" + " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" + " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" + " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" + " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" + " point, and we'll know even more about Iran's program in the coming months and years because of the deal." + " In fact, the inspections provisions that are part of this agreement are designed to protect against any" + " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" + " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" + " warning that a deal might be killed by Congress or a future president). This of course is not the case." + " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," + " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" + " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" + " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" + " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" + " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" + " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" + " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" + " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" + " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" + " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" + " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" + ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' + " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" + " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" + " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" + " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" + " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" + " some insist that any agreement must address Iranian missile programs, human rights violations or support" + " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" + " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" + " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" + " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" + " fact-based, not based on questionable assertions or dubious assumptions." + ) + ARTICLE_SUBWAY = ( + "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + + expected_summaries = [ + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' + " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" + " magazine says .", + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" + " preliminary examination into the situation in the occupied Palestinian territory . as members of the" + " court, Palestinians may be subject to counter-charges as well .", + "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" + " the debate that has already begun since the announcement of the new framework will likely result in more" + " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" + " implement a rigorous inspection regime .", + "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" + ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' + " times, with nine of her marriages occurring between 1999 and 2002 .", + ] + + use_task_specific_params(model, "summarization") + + dct = tok( + [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(torch_device) + self.assertEqual(512, dct["input_ids"].shape[1]) + + hypotheses_batch = model.generate( + **dct, + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertListEqual( + expected_summaries, + decoded, + ) + + @slow + def test_translation_en_to_de(self): + model = self.model + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_de") + + en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' + expected_translation = ( + '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' + ) + + input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") + input_ids = input_ids.to(torch_device) + output = model.generate(input_ids) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(translation, expected_translation) + + @slow + def test_translation_en_to_fr(self): + model = self.model # google-timesfm/timesfm-base + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_fr") + + en_text = ( + ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' + " countless generations of stars: the oldest stars are seen as blue dots. " + ) + + input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") + input_ids = input_ids.to(torch_device) + + output = model.generate( + input_ids=input_ids, + num_beams=4, + length_penalty=2.0, + max_length=100, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + new_truncated_translation = ( + "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " + "un " + "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " + "sous forme " + "de points bleus." + ) + + self.assertEqual(translation, new_truncated_translation) + + @slow + def test_translation_en_to_ro(self): + model = self.model + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_ro") + en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022." + expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." + + inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device) + output = model.generate(**inputs) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(translation, expected_translation) + + @slow + def test_contrastive_search_timesfm(self): + article = ( + " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + article = "summarize: " + article.strip() + timesfm_tokenizer = AutoTokenizer.from_pretrained("flax-community/timesfm-base-cnn-dm") + timesfm_model = TimesFMForConditionalGeneration.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) + input_ids = timesfm_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" + ).input_ids.to(torch_device) + + outputs = timesfm_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = timesfm_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " + "permanent residence after the marriages, prosecutors say." + ], + ) + + +@require_torch +class TestAsymmetricTimesFM(unittest.TestCase): + def build_model_and_check_forward_pass(self, **kwargs): + tester = TimesFMModelTester(self, **kwargs) + config, *inputs = tester.prepare_config_and_inputs() + ( + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = inputs + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + # outputs = model(*inputs) + assert len(outputs) == 4 + assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size) + assert outputs["loss"].size() == () + return model + + def test_small_decoder(self): + # num_hidden_layers is passed to TimesFMConfig as num_layers + model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2) + assert len(model.encoder.block) == 2 + assert len(model.decoder.block) == 1 + + def test_defaulting_to_symmetry(self): + # num_hidden_layers is passed to TimesFMConfig as num_layers + model = self.build_model_and_check_forward_pass(num_hidden_layers=2) + assert len(model.decoder.block) == len(model.encoder.block) == 2 From b7b4b916821792c45d28033cd3f34503fb4825cf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Aug 2024 23:25:52 +0200 Subject: [PATCH 002/242] added config and attention layers --- .../models/timesfm/configuration_timesfm.py | 15 ++--- .../models/timesfm/modeling_timesfm.py | 67 +++++++------------ 2 files changed, 29 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 065f779b557c..ad66b752b3b3 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -35,9 +35,6 @@ class TimesFMConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Arguments: - vocab_size (`int`, *optional*, defaults to 32128): - Vocabulary size of the TimesFM model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`TimesFMModel`] or [`TFTimesFMModel`]. d_model (`int`, *optional*, defaults to 512): Size of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): @@ -77,13 +74,12 @@ class TimesFMConfig(PretrainedConfig): def __init__( self, - vocab_size=32128, - d_model=512, - d_kv=64, - d_ff=2048, - num_layers=6, + d_model=1280, + d_kv=80, + d_ff=1280, + num_layers=20, num_decoder_layers=None, - num_heads=8, + num_heads=16, relative_attention_num_buckets=32, relative_attention_max_distance=128, dropout_rate=0.1, @@ -97,7 +93,6 @@ def __init__( classifier_dropout=0.0, **kwargs, ): - self.vocab_size = vocab_size self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ea6daa33ac6b..ee4701cdaded 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -21,6 +21,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -270,12 +271,11 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) -# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->TimesFM class TimesFMDenseActDense(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.wi = nn.Linear(config.d_model, config.d_ff, bias=True) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=True) self.dropout = nn.Dropout(config.dropout_rate) self.act = ACT2FN[config.dense_act_fn] @@ -293,56 +293,35 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->TimesFM,t5->timesfm -class TimesFMDenseGatedActDense(nn.Module): +class TimesFMLayerFF(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + + self.DenseReluDense = TimesFMDenseActDense(config) + self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - self.act = ACT2FN[config.dense_act_fn] def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - # To make 8bit quantization work for google/flan-timesfm-xxl, self.wo is kept in float32. - # See https://github.com/huggingface/transformers/issues/20287 - # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): - hidden_states = hidden_states.to(self.wo.weight.dtype) - - hidden_states = self.wo(hidden_states) + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states -# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->TimesFM -class TimesFMLayerFF(nn.Module): +class TimesFMPerHeadDimScale(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - if config.is_gated_act: - self.DenseReluDense = TimesFMDenseGatedActDense(config) - else: - self.DenseReluDense = TimesFMDenseActDense(config) - self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) + self.dim = config.d_model // config.num_heads + self.scale = nn.Parameter(torch.zeros(self.dim)) def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states + r_softplus_0 = 1.442695041 + scale = r_softplus_0 / math.sqrt(self.dim) + scale *= F.softplus(self.scale) + return hidden_states * scale -# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->TimesFM class TimesFMAttention(nn.Module): def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): super().__init__() @@ -357,10 +336,11 @@ def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): self.inner_dim = self.n_heads * self.key_value_proj_dim # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + self.q = nn.Linear(self.d_model, self.inner_dim, bias=True) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=True) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=True) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=True) + self.per_head_dim_scale = TimesFMPerHeadDimScale(config) if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) @@ -515,7 +495,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return hidden_states # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + unscaled_query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.per_head_dim_scale(unscaled_query_states) # get key/value states key_states = project( From b604a99a8b3ed9c1fa4fd8799a4eab7abe2091a6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Aug 2024 23:33:36 +0200 Subject: [PATCH 003/242] add TimesFMPositionalEmbedding --- .../models/timesfm/modeling_timesfm.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ee4701cdaded..4f45d4f4e8c8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -271,6 +271,59 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) +class TimesFMPositionalEmbedding(nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__(self, min_timescale=1, max_timescale=10000, embedding_dims=0): + super().__init__() + + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + if seq_length is None: + raise ValueError("If position is None, seq_length should be specified.") + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + if position.ndim != 2: + raise ValueError(f"position should have 2 dimensions, got {position.ndim}") + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max( + torch.tensor(num_timescales, dtype=torch.float32) - 1, 1 + ) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2).type(torch.float32) + + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + class TimesFMDenseActDense(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() From 8e6e47caa3a564bd5c8b5778837bb6b4907f2592 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 31 Aug 2024 11:31:20 +0200 Subject: [PATCH 004/242] calcuate scale_factor once --- .../models/timesfm/modeling_timesfm.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 4f45d4f4e8c8..8b338d7ae153 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -364,14 +364,13 @@ def forward(self, hidden_states): class TimesFMPerHeadDimScale(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - - self.dim = config.d_model // config.num_heads - self.scale = nn.Parameter(torch.zeros(self.dim)) + dim = config.d_model // config.num_heads + r_softplus_0 = 1.442695041 + self.scale_factor = r_softplus_0 / math.sqrt(dim) + self.scale = nn.Parameter(torch.empty(self.dim)) def forward(self, hidden_states): - r_softplus_0 = 1.442695041 - scale = r_softplus_0 / math.sqrt(self.dim) - scale *= F.softplus(self.scale) + scale = self.scale_factor * F.softplus(self.scale) return hidden_states * scale @@ -890,16 +889,8 @@ def _init_weights(self, module): module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() - elif isinstance(module, TimesFMDenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) - if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + elif isinstance(module, TimesFMPerHeadDimScale): + module.scale.data.zero_() elif isinstance(module, TimesFMAttention): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 From 56718a81bcebbc9bdcad886fd946b0ed26a15c1e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 31 Aug 2024 12:08:30 +0200 Subject: [PATCH 005/242] add more configs and TimesFMResidualBlock --- .../models/timesfm/configuration_timesfm.py | 28 +++++++++++++- .../models/timesfm/modeling_timesfm.py | 37 ++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index ad66b752b3b3..933842f977ad 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -14,7 +14,7 @@ # limitations under the License. """TimesFM model configuration""" -from typing import Mapping +from typing import List, Mapping from ...configuration_utils import PretrainedConfig from ...onnx import OnnxSeq2SeqConfigWithPast @@ -35,12 +35,24 @@ class TimesFMConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Arguments: + patch_len (`int`, *optional*, defaults to 32): + The length of each patch in the sequence. + horizon_len (`int`, *optional*, defaults to 128): + The length of the prediction horizon. + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): + The quantiles to predict. + pad_val (`float`, *optional*, defaults to 1123581321.0): + The value used to pad the predictions. + tolerance (`float`, *optional*, defaults to 1e-6): + The tolerance for the quantile loss. + freq_size (`int`, *optional*, defaults to 3): + The number of frequency embeddings. d_model (`int`, *optional*, defaults to 512): Size of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * d_kv`. - d_ff (`int`, *optional*, defaults to 2048): + d_ff (`int`, *optional*, defaults to 1280): Size of the intermediate feed forward layer in each `TimesFMBlock`. num_layers (`int`, *optional*, defaults to 6): Number of hidden layers in the Transformer encoder. @@ -74,6 +86,12 @@ class TimesFMConfig(PretrainedConfig): def __init__( self, + patch_len: int = 32, + horizon_len: int = 128, + quantiles: List[float] = [0.1, 0.25, 0.5, 0.75, 0.9], + pad_val: float = 1123581321.0, + tolerance: float = 1e-6, + freq_size=3, d_model=1280, d_kv=80, d_ff=1280, @@ -93,6 +111,12 @@ def __init__( classifier_dropout=0.0, **kwargs, ): + self.patch_len = patch_len + self.horizon_len = horizon_len + self.quantiles = quantiles + self.pad_val = pad_val + self.tolerance = tolerance + self.freq_size = freq_size self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8b338d7ae153..61010b17929d 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -271,6 +271,24 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) +class TimesFMResidualBlock(nn.Module): + def __init__(self, input_dims, hidden_dims, output_dims, dropout=0.1): + super().__init__() + + self.hidden_layer = nn.Sequential(nn.Linear(input_dims, hidden_dims), nn.SiLU()) + self.output_layer = nn.Linear(hidden_dims, output_dims) + self.residual_layer = nn.Linear(input_dims, output_dims) + self.dropout = nn.Dropout(dropout) + + def forward(self, inputs): + hidden = self.hidden_layer(inputs) + output = self.output_layer(hidden) + output = self.dropout(output) + residual = self.residual_layer(inputs) + + return output + residual + + class TimesFMPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. @@ -932,7 +950,6 @@ def _shift_right(self, input_ids): return shifted_input_ids -# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->TimesFM class TimesFMStack(TimesFMPreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) @@ -943,7 +960,6 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList( [TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ) - self.final_layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) # Initialize weights and apply final processing @@ -1182,7 +1198,6 @@ def forward( if i == v[-1] and "cuda:" + str(k) != self.last_device: hidden_states = hidden_states.to("cuda:" + str(k + 1)) - hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) # Add last layer @@ -1382,7 +1397,13 @@ class TimesFMModel(TimesFMPreTrainedModel): def __init__(self, config: TimesFMConfig): super().__init__(config) - self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.freq_emb = nn.Embedding(config.freq_size, config.d_model) + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.hidden_size, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False @@ -1909,7 +1930,13 @@ class TimesFMEncoderModel(TimesFMPreTrainedModel): def __init__(self, config: TimesFMConfig): super().__init__(config) - self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.freq_emb = nn.Embedding(config.freq_size, config.d_model) + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.d_model, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) encoder_config = copy.deepcopy(config) encoder_config.use_cache = False From bd904c83426ae1884e6cd1c23a1924b1f3ef5a67 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 31 Aug 2024 22:02:11 +0200 Subject: [PATCH 006/242] fix input_dims --- .../models/timesfm/modeling_timesfm.py | 105 +----------------- 1 file changed, 2 insertions(+), 103 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 61010b17929d..1a2d93b67694 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -951,10 +951,9 @@ def _shift_right(self, input_ids): class TimesFMStack(TimesFMPreTrainedModel): - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) - self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder self.block = nn.ModuleList( @@ -965,60 +964,9 @@ def __init__(self, config, embed_tokens=None): # Initialize weights and apply final processing self.post_init() # Model parallel - self.model_parallel = False self.device_map = None self.gradient_checkpointing = False - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMStack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" - " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," - " 'block.1': 1, ...}", - FutureWarning, - ) - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - # Load onto devices - for k, v in self.device_map.items(): - for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - # Set embed_tokens to first layer - self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device - self.final_layer_norm = self.final_layer_norm.to(self.last_device) - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, new_embeddings): - self.embed_tokens = new_embeddings - def forward( self, input_ids=None, @@ -1399,7 +1347,7 @@ def __init__(self, config: TimesFMConfig): super().__init__(config) self.freq_emb = nn.Embedding(config.freq_size, config.d_model) self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.hidden_size, + input_dims=config.d_model, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.d_ff, dropout=config.dropout_rate, @@ -1950,58 +1898,9 @@ def __init__(self, config: TimesFMConfig): self.model_parallel = False self.device_map = None - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMEncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" - " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," - " 'block.1': 1, ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.encoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) - @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) def forward( From 45782ed3d1f3dc062a14da2d4196fa8a48bc9979 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 5 Sep 2024 14:41:53 -0700 Subject: [PATCH 007/242] standardize code format with black --- src/transformers/models/timesfm/__init__.py | 4 +- .../models/timesfm/configuration_timesfm.py | 16 +- ...mesfm_original_tf_checkpoint_to_pytorch.py | 26 +- .../convert_timesfmx_checkpoint_to_flax.py | 320 +++++++------ .../convert_timesfmx_checkpoint_to_pytorch.py | 79 +++- .../models/timesfm/modeling_timesfm.py | 434 +++++++++++++----- 6 files changed, 619 insertions(+), 260 deletions(-) diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 7398dccbda88..1abef3d3e175 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -63,4 +63,6 @@ else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__ + ) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 933842f977ad..cdee64d7f377 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -82,7 +82,11 @@ class TimesFMConfig(PretrainedConfig): model_type = "timesfm" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } def __init__( self, @@ -167,10 +171,16 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: if self.use_past: common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" common_inputs["decoder_input_ids"] = {0: "batch"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + common_inputs["decoder_attention_mask"] = { + 0: "batch", + 1: "past_decoder_sequence + sequence", + } else: common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = { + 0: "batch", + 1: "decoder_sequence", + } if self.use_past: self.fill_with_past_key_values_(common_inputs, direction="inputs") diff --git a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py index aa66a8392d4f..b1ce727cac0c 100644 --- a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py @@ -16,14 +16,20 @@ import argparse -from transformers import TimesFMConfig, TimesFMForConditionalGeneration, load_tf_weights_in_timesfm +from transformers import ( + TimesFMConfig, + TimesFMForConditionalGeneration, + load_tf_weights_in_timesfm, +) from transformers.utils import logging logging.set_verbosity_info() -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): +def convert_tf_checkpoint_to_pytorch( + tf_checkpoint_path, config_file, pytorch_dump_path +): # Initialise PyTorch model config = TimesFMConfig.from_json_file(config_file) print(f"Building PyTorch model from configuration: {config}") @@ -41,7 +47,11 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + "--tf_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint path.", ) parser.add_argument( "--config_file", @@ -53,7 +63,13 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du ), ) parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + "--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", ) args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path + ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py index 98570e22876e..f9468ffb84c6 100644 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py @@ -22,7 +22,9 @@ from transformers import FlaxTimesFMForConditionalGeneration, TimesFMConfig -def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, flax_dump_folder_path): +def convert_timesfmx_checkpoint_to_flax( + timesfmx_checkpoint_path, config_name, flax_dump_folder_path +): config = TimesFMConfig.from_pretrained(config_name) flax_model = FlaxTimesFMForConditionalGeneration(config=config) timesfmx_model = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) @@ -34,67 +36,89 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f layer_name = f"layers_{str(layer_index)}" # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["value"]["kernel"] # Layer Normalization - timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ + "pre_attention_layer_norm" + ]["scale"] if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ + "wi_0" + ]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ + "wi_1" + ]["kernel"] else: - timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ + "wi" + ]["kernel"] - timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"][ + "kernel" + ] # Layer Normalization - timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ + "pre_mlp_layer_norm" + ]["scale"] # Assigning - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( - timesfmx_attention_key - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( - timesfmx_attention_out - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( - timesfmx_attention_query - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( - timesfmx_attention_value - ) - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( - timesfmx_attention_layer_norm - ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["k"]["kernel"] = timesfmx_attention_key + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["o"]["kernel"] = timesfmx_attention_out + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["q"]["kernel"] = timesfmx_attention_query + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["v"]["kernel"] = timesfmx_attention_value + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "layer_norm" + ]["weight"] = timesfmx_attention_layer_norm if split_mlp_wi: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = timesfmx_mlp_wi_0 - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = timesfmx_mlp_wi_1 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 else: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = ( - timesfmx_mlp_wi - ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wi"]["kernel"] = timesfmx_mlp_wi - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = ( - timesfmx_mlp_wo - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( - timesfmx_mlp_layer_norm - ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wo"]["kernel"] = timesfmx_mlp_wo + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "layer_norm" + ]["weight"] = timesfmx_mlp_layer_norm # Only for layer 0: - timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = timesfmx_encoder_rel_embedding + timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"][ + "rel_embedding" + ].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ + "relative_attention_bias" + ]["embedding"] = timesfmx_encoder_rel_embedding # Assigning timesfmx_encoder_norm = timesfmx_model["target"]["encoder"]["encoder_norm"]["scale"] @@ -105,109 +129,131 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f layer_name = f"layers_{str(layer_index)}" # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["value"]["kernel"] # Layer Normalization - timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ - "scale" - ] + timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][ + layer_name + ]["pre_self_attention_layer_norm"]["scale"] # Encoder-Decoder-Attention - timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ - "kernel" - ] - timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ - "kernel" - ] - timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ - "kernel" - ] - timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ - "kernel" - ] + timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["key"]["kernel"] + timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["out"]["kernel"] + timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["query"]["kernel"] + timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["value"]["kernel"] # Layer Normalization - timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ + "pre_cross_attention_layer_norm" + ]["scale"] # MLP if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ + "wi_0" + ]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ + "wi_1" + ]["kernel"] else: - timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ + "wi" + ]["kernel"] - timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"][ + "kernel" + ] # Layer Normalization - tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ + "pre_mlp_layer_norm" + ]["scale"] # Assigning - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( - timesfmx_attention_key - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( - timesfmx_attention_out - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( - timesfmx_attention_query - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( - timesfmx_attention_value - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( - timesfmx_pre_attention_layer_norm - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = ( - timesfmx_enc_dec_attention_key - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = ( - timesfmx_enc_dec_attention_out - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = ( - timesfmx_enc_dec_attention_query - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = ( - timesfmx_enc_dec_attention_value - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( - timesfmx_cross_layer_norm - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["k"]["kernel"] = timesfmx_attention_key + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["o"]["kernel"] = timesfmx_attention_out + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["q"]["kernel"] = timesfmx_attention_query + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["v"]["kernel"] = timesfmx_attention_value + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "layer_norm" + ]["weight"] = timesfmx_pre_attention_layer_norm + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["k"]["kernel"] = timesfmx_enc_dec_attention_key + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["o"]["kernel"] = timesfmx_enc_dec_attention_out + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["q"]["kernel"] = timesfmx_enc_dec_attention_query + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["v"]["kernel"] = timesfmx_enc_dec_attention_value + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "layer_norm" + ]["weight"] = timesfmx_cross_layer_norm if split_mlp_wi: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = timesfmx_mlp_wi_0 - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = timesfmx_mlp_wi_1 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 else: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = ( - timesfmx_mlp_wi - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wi"]["kernel"] = timesfmx_mlp_wi - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = ( - timesfmx_mlp_wo - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wo"]["kernel"] = timesfmx_mlp_wo - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = ( - tx5_mlp_layer_norm - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "layer_norm" + ]["weight"] = tx5_mlp_layer_norm # Decoder Normalization tx5_decoder_norm = timesfmx_model["target"]["decoder"]["decoder_norm"]["scale"] flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm # Only for layer 0: - timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = timesfmx_decoder_rel_embedding + timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"][ + "rel_embedding" + ].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ + "relative_attention_bias" + ]["embedding"] = timesfmx_decoder_rel_embedding # Token Embeddings tx5_token_embeddings = timesfmx_model["target"]["token_embedder"]["embedding"] @@ -215,7 +261,9 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f # LM Head (only in v1.1 checkpoints) if "logits_dense" in timesfmx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"]["logits_dense"]["kernel"] + flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"][ + "logits_dense" + ]["kernel"] flax_model.save_pretrained(flax_dump_folder_path) print("TimesFMX Model was sucessfully converted!") @@ -225,11 +273,27 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + "--timesfmx_checkpoint_path", + default=None, + type=str, + required=True, + help="Path the TX5 checkpoint.", + ) + parser.add_argument( + "--config_name", + default=None, + type=str, + required=True, + help="Config name of TimesFM model.", ) - parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of TimesFM model.") parser.add_argument( - "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + "--flax_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output FLAX model.", ) args = parser.parse_args() - convert_timesfmx_checkpoint_to_flax(args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path) + convert_timesfmx_checkpoint_to_flax( + args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path + ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py index b761d76bbdcd..8d5f13535e8d 100644 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py @@ -35,7 +35,11 @@ from flax import traverse_util from timesfmx import checkpoints -from transformers import TimesFMConfig, TimesFMEncoderModel, TimesFMForConditionalGeneration +from transformers import ( + TimesFMConfig, + TimesFMEncoderModel, + TimesFMForConditionalGeneration, +) from transformers.utils import logging @@ -69,7 +73,9 @@ def timesfmx_layer_norm_lookup(params, i, prefix, layer_name): return params[f"{prefix}/layers_{i}/{layer_name}/scale"] -def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool): +def convert_timesfmx_to_pytorch( + variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool +): """Converts the parameters from TimesFMX-Flax to Transformers-PyTorch.""" old = traverse_util.flatten_dict(variables["target"]) old = {"/".join(k): v for k, v in old.items()} @@ -86,7 +92,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder # Encoder. for i in range(num_layers): # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "encoder", "pre_attention_layer_norm" + ) k, o, q, v = timesfmx_attention_lookup(old, i, "encoder", "attention") new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T @@ -114,7 +122,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder # Decoder. for i in range(num_decoder_layers): # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "decoder", "pre_self_attention_layer_norm" + ) k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "self_attention") new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T @@ -123,8 +133,12 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T # Block i, layer 1 (Cross Attention). - layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") - k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "encoder_decoder_attention") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "decoder", "pre_cross_attention_layer_norm" + ) + k, o, q, v = timesfmx_attention_lookup( + old, i, "decoder", "encoder_decoder_attention" + ) new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T @@ -132,7 +146,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T # Block i, layer 2 (MLP). - layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "decoder", "pre_mlp_layer_norm" + ) wi, wo = timesfmx_mlp_lookup(old, i, "decoder", split_mlp_wi) new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm if split_mlp_wi: @@ -143,9 +159,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] - new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ - "decoder/relpos_bias/rel_embedding" - ].T + new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = ( + old["decoder/relpos_bias/rel_embedding"].T + ) # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) if "decoder/logits_dense/kernel" in old: @@ -157,7 +173,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder def make_state_dict(converted_params, is_encoder_only: bool): """Prepares a state dict for the PyTorch model.""" # Make a state dict with torch tensors. - state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + state_dict = collections.OrderedDict( + [(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()] + ) # Add what is missing. if "encoder.embed_tokens.weight" not in state_dict: @@ -174,7 +192,9 @@ def make_state_dict(converted_params, is_encoder_only: bool): return state_dict -def load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only): +def load_timesfmx_weights_in_timesfm( + model, config, timesfmx_checkpoint_path, is_encoder_only +): """Replaces the params in model witht the TimesFMX converted params.""" variables = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) converted = convert_timesfmx_to_pytorch( @@ -188,7 +208,10 @@ def load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is def convert_timesfmx_checkpoint_to_pytorch( - timesfmx_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False + timesfmx_checkpoint_path, + config_file, + pytorch_dump_path, + is_encoder_only: bool = False, ): """Loads the config and model, converts the TimesFMX checkpoint, and saves a PyTorch checkpoint.""" # Initialise PyTorch model @@ -202,7 +225,9 @@ def convert_timesfmx_checkpoint_to_pytorch( model = TimesFMForConditionalGeneration(config) # Load weights from tf checkpoint - load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only) + load_timesfmx_weights_in_timesfm( + model, config, timesfmx_checkpoint_path, is_encoder_only + ) # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") @@ -214,10 +239,16 @@ def convert_timesfmx_checkpoint_to_pytorch( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint.") + parser = argparse.ArgumentParser( + description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint." + ) # Required parameters parser.add_argument( - "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path to the TimesFMX checkpoint." + "--timesfmx_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the TimesFMX checkpoint.", ) parser.add_argument( "--config_file", @@ -227,12 +258,22 @@ def convert_timesfmx_checkpoint_to_pytorch( help="The config json file corresponding to the pre-trained TimesFM model.\nThis specifies the model architecture.", ) parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + "--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", ) parser.add_argument( - "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False + "--is_encoder_only", + action="store_true", + help="Check if the model is encoder-decoder model", + default=False, ) args = parser.parse_args() convert_timesfmx_checkpoint_to_pytorch( - args.timesfmx_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only + args.timesfmx_checkpoint_path, + args.config_file, + args.pytorch_dump_path, + args.is_encoder_only, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 1a2d93b67694..8542c54277bc 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -36,7 +36,11 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -96,7 +100,14 @@ def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] for n in name ): logger.info(f"Skipping {'/'.join(name)}") @@ -140,7 +151,11 @@ def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): continue elif scope_names[0] == "logits": pointer = getattr(pointer, "lm_head") - elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + elif ( + scope_names[0] == "wi" + and len(scope_names) > 1 + and scope_names[1].isdigit() + ): pointer = getattr(pointer, f"wi_{scope_names[1]}") continue else: @@ -159,7 +174,9 @@ def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): array = np.transpose(array) try: if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + raise ValueError( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) except AssertionError as e: e.args += (pointer.shape, array.shape) raise @@ -260,12 +277,16 @@ def forward(self, hidden_states): TimesFMLayerNorm = FusedRMSNorm # noqa - logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm") + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm" + ) except ImportError: # using the normal TimesFMLayerNorm pass except Exception: - logger.warning("discovered apex but it failed to load, falling back to TimesFMLayerNorm") + logger.warning( + "discovered apex but it failed to load, falling back to TimesFMLayerNorm" + ) pass ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) @@ -326,17 +347,21 @@ def forward(self, seq_length=None, position=None): position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) else: if position.ndim != 2: - raise ValueError(f"position should have 2 dimensions, got {position.ndim}") + raise ValueError( + f"position should have 2 dimensions, got {position.ndim}" + ) num_timescales = self.embedding_dims // 2 - log_timescale_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max( - torch.tensor(num_timescales, dtype=torch.float32) - 1, 1 - ) + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / max(torch.tensor(num_timescales, dtype=torch.float32) - 1, 1) inv_timescales = self.min_timescale * torch.exp( torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment ) scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2).type(torch.float32) + signal = torch.cat( + [torch.sin(scaled_time), torch.cos(scaled_time)], dim=2 + ).type(torch.float32) signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) return signal @@ -413,7 +438,9 @@ def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): self.per_head_dim_scale = TimesFMPerHeadDimScale(config) if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) self.pruned_heads = set() self.gradient_checkpointing = False @@ -434,7 +461,9 @@ def prune_heads(self, heads): self.pruned_heads = self.pruned_heads.union(heads) @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 @@ -461,7 +490,9 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -475,27 +506,40 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets * (num_buckets - max_exact) ).to(torch.long) relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), ) - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) return relative_buckets def compute_bias(self, query_length, key_length, device=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), num_buckets=self.relative_attention_num_buckets, max_distance=self.relative_attention_max_distance, ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) return values def forward( @@ -525,17 +569,25 @@ def forward( raise ValueError( f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) def shape(states): """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) def unshape(states): """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -565,15 +617,23 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return hidden_states # get query states - unscaled_query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + unscaled_query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) query_states = self.per_head_dim_scale(unscaled_query_states) # get key/value states key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, ) value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, ) # compute scores @@ -584,12 +644,16 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) # if key and values are already calculated # we want only the last query position bias @@ -597,7 +661,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -618,10 +684,14 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) if output_attentions: @@ -633,8 +703,12 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class TimesFMLayerSelfAttention(nn.Module): def __init__(self, config, has_relative_attention_bias=False): super().__init__() - self.SelfAttention = TimesFMAttention(config, has_relative_attention_bias=has_relative_attention_bias) - self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.SelfAttention = TimesFMAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = TimesFMLayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -658,7 +732,9 @@ def forward( output_attentions=output_attentions, ) hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them return outputs @@ -666,8 +742,12 @@ def forward( class TimesFMLayerCrossAttention(nn.Module): def __init__(self, config): super().__init__() - self.EncDecAttention = TimesFMAttention(config, has_relative_attention_bias=False) - self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.EncDecAttention = TimesFMAttention( + config, has_relative_attention_bias=False + ) + self.layer_norm = TimesFMLayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -695,7 +775,9 @@ def forward( output_attentions=output_attentions, ) layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them return outputs @@ -705,7 +787,11 @@ def __init__(self, config, has_relative_attention_bias=False): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(TimesFMLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + TimesFMLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) if self.is_decoder: self.layer.append(TimesFMLayerCrossAttention(config)) @@ -728,7 +814,9 @@ def forward( ): if past_key_value is not None: if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 if len(past_key_value) != expected_num_past_key_values: @@ -753,7 +841,9 @@ def forward( output_attentions=output_attentions, ) hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -762,7 +852,9 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: @@ -793,11 +885,15 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) # Combine self attn and cross attn key value states if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -812,7 +908,9 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) outputs = (hidden_states,) @@ -871,12 +969,19 @@ def dummy_inputs(self): def _init_weights(self, module): """Initialize the weights""" - factor = self.config.initializer_factor # Used for testing weights initialization + factor = ( + self.config.initializer_factor + ) # Used for testing weights initialization if isinstance(module, TimesFMLayerNorm): module.weight.data.fill_(factor * 1.0) elif isinstance( module, - (TimesFMModel, TimesFMForConditionalGeneration, TimesFMEncoderModel, TimesFMForQuestionAnswering), + ( + TimesFMModel, + TimesFMForConditionalGeneration, + TimesFMEncoderModel, + TimesFMForQuestionAnswering, + ), ): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 @@ -884,27 +989,37 @@ def _init_weights(self, module): if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) module.qa_outputs.bias.data.zero_() elif isinstance(module, TimesFMForTokenClassification): if hasattr(module, "classifier"): module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) module.classifier.bias.data.zero_() elif isinstance(module, TimesFMClassificationHead): - module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.dense, "bias") and module.dense.bias is not None: module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.out_proj.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: module.out_proj.bias.data.zero_() elif isinstance(module, TimesFMDenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, TimesFMPerHeadDimScale): @@ -915,12 +1030,18 @@ def _init_weights(self, module): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -935,8 +1056,12 @@ def _shift_right(self, input_ids): # shift inputs to the right if is_torch_fx_proxy(input_ids): # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) - shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) else: shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() @@ -957,7 +1082,10 @@ def __init__(self, config): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] ) self.dropout = nn.Dropout(config.dropout_rate) @@ -987,11 +1115,19 @@ def forward( torch.cuda.set_device(self.first_device) self.embed_tokens = self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: err_msg_prefix = "decoder_" if self.is_decoder else "" @@ -1005,43 +1141,61 @@ def forward( input_shape = inputs_embeds.size()[:-1] else: err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) if inputs_embeds is None: if self.embed_tokens is None: - raise ValueError("You have to initialize the model with valid token embeddings") + raise ValueError( + "You have to initialize the model with valid token embeddings" + ) inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) if use_cache is True: if not self.is_decoder: - raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + raise ValueError( + f"`use_cache` can only be set to `True` if {self} is used as a decoder" + ) # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_batch_size, encoder_sequence_length, _ = ( + encoder_hidden_states.size() + ) encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones( encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long ) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) else: encoder_extended_attention_mask = None @@ -1054,7 +1208,9 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1064,7 +1220,9 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel @@ -1076,15 +1234,23 @@ def forward( if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device + ) if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + encoder_extended_attention_mask = ( + encoder_extended_attention_mask.to(hidden_states.device) + ) if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) if layer_head_mask is not None: layer_head_mask = layer_head_mask.to(hidden_states.device) if cross_attn_layer_head_mask is not None: - cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1130,10 +1296,14 @@ def forward( # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] # append next layer key value states if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1433,7 +1603,9 @@ class PreTrainedModel self.encoder.layer[layer].attention.prune_heads(heads) @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1477,7 +1649,9 @@ def forward( >>> last_hidden_states = outputs.last_hidden_state ```""" use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: @@ -1514,7 +1688,9 @@ def forward( if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) # Decode decoder_outputs = self.decoder( @@ -1547,12 +1723,18 @@ def forward( ) -@add_start_docstrings("""TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING) +@add_start_docstrings( + """TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING +) class TimesFMForConditionalGeneration(TimesFMPreTrainedModel): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "lm_head.weight", + ] def __init__(self, config: TimesFMConfig): super().__init__(config) @@ -1642,7 +1824,9 @@ def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1694,7 +1878,9 @@ def forward( >>> # studies have shown that owning a dog is good for you. ```""" use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: @@ -1726,7 +1912,11 @@ def forward( if self.model_parallel: torch.cuda.set_device(self.decoder.first_device) - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) @@ -1739,7 +1929,9 @@ def forward( if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) # Decode decoder_outputs = self.decoder( @@ -1841,7 +2033,9 @@ def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) return past_key_values reordered_decoder_past = () @@ -1852,7 +2046,9 @@ def _reorder_cache(self, past_key_values, beam_idx): for layer_past_state in layer_past_states: # need to set correct `past` for each of the four key / value states reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), ) if reordered_layer_past_states[0].shape != layer_past_states[0].shape: @@ -1864,7 +2060,9 @@ def _reorder_cache(self, past_key_values, beam_idx): f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" ) - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) return reordered_decoder_past @@ -1902,7 +2100,9 @@ def get_encoder(self): return self.encoder @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1929,7 +2129,9 @@ def forward( >>> outputs = model(input_ids=input_ids) >>> last_hidden_states = outputs.last_hidden_state ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) encoder_outputs = self.encoder( input_ids=input_ids, @@ -1952,7 +2154,9 @@ def forward( TIMESFM_START_DOCSTRING, ) class TimesFMForSequenceClassification(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" + ] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] def __init__(self, config: TimesFMConfig): @@ -1966,7 +2170,9 @@ def __init__(self, config: TimesFMConfig): self.model_parallel = False @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: torch.LongTensor = None, @@ -1991,7 +2197,9 @@ def forward( config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) if labels is not None: use_cache = False @@ -2033,7 +2241,9 @@ def forward( if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") batch_size, _, hidden_size = sequence_output.shape - sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + sentence_representation = sequence_output[eos_mask, :].view( + batch_size, -1, hidden_size + )[:, -1, :] logits = self.classification_head(sentence_representation) loss = None @@ -2042,7 +2252,9 @@ def forward( if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" - elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + elif self.config.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -2055,7 +2267,9 @@ def forward( loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + loss = loss_fct( + logits.view(-1, self.config.num_labels), labels.view(-1) + ) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) @@ -2098,7 +2312,9 @@ def __init__(self, config: TimesFMConfig): self.post_init() @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -2115,7 +2331,9 @@ def forward( Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) outputs = self.transformer( input_ids, @@ -2156,7 +2374,9 @@ def forward( TIMESFM_START_DOCSTRING, ) class TimesFMForQuestionAnswering(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" + ] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] def __init__(self, config: TimesFMConfig): @@ -2205,7 +2425,9 @@ def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2236,7 +2458,9 @@ def forward( are not taken into account for computing the loss. Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) use_cache = use_cache if use_cache is not None else self.config.use_cache if start_positions is not None and end_positions is not None: use_cache = False @@ -2253,7 +2477,9 @@ def forward( decoder_input_ids = self._shift_right(input_ids) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: From 651ebf294201515a282fa829c44cd94e9a87cc01 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 5 Sep 2024 15:49:34 -0700 Subject: [PATCH 008/242] remove unneeded modules --- src/transformers/__init__.py | 14 +- src/transformers/models/auto/modeling_auto.py | 10 +- .../models/auto/tokenization_auto.py | 6 - src/transformers/models/timesfm/__init__.py | 14 +- .../models/timesfm/modeling_timesfm.py | 512 +----------------- tests/models/timesfm/test_modeling_timesfm.py | 156 +----- 6 files changed, 33 insertions(+), 679 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4ac465164d69..5bb79f8c539d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3359,14 +3359,9 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMEncoderModel", - "TimesFMForConditionalGeneration", - "TimesFMForQuestionAnswering", - "TimesFMForSequenceClassification", - "TimesFMForTokenClassification", + "TimesFMForPrediction", "TimesFMModel", "TimesFMPreTrainedModel", - "load_tf_weights_in_timesfm", ] ) _import_structure["models.table_transformer"].extend( @@ -7766,14 +7761,9 @@ load_tf_weights_in_t5, ) from .models.timesfm import ( - TimesFMEncoderModel, - TimesFMForConditionalGeneration, - TimesFMForQuestionAnswering, - TimesFMForSequenceClassification, - TimesFMForTokenClassification, + TimesFMForPrediction, TimesFMModel, TimesFMPreTrainedModel, - load_tf_weights_in_timesfm, ) from .models.table_transformer import ( TableTransformerForObjectDetection, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a3919a2005c8..16b7ea2dbe2d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -340,7 +340,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForConditionalGeneration"), + ("timesfm", "TimesFMForPrediction"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -434,7 +434,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForConditionalGeneration"), + ("timesfm", "TimesFMForPrediction"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -848,7 +848,7 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForConditionalGeneration"), + ("timesfm", "TimesFMForPrediction"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] @@ -948,7 +948,6 @@ ("stablelm", "StableLmForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), - ("timesfm", "TimesFMForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), @@ -1021,7 +1020,6 @@ ("splinter", "SplinterForQuestionAnswering"), ("squeezebert", "SqueezeBertForQuestionAnswering"), ("t5", "T5ForQuestionAnswering"), - ("timesfm", "TimesFMForQuestionAnswering"), ("umt5", "UMT5ForQuestionAnswering"), ("xlm", "XLMForQuestionAnsweringSimple"), ("xlm-roberta", "XLMRobertaForQuestionAnswering"), @@ -1122,7 +1120,6 @@ ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), - ("timesfm", "TimesFMForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -1344,7 +1341,6 @@ ("roformer", "RoFormerModel"), ("squeezebert", "SqueezeBertModel"), ("t5", "T5EncoderModel"), - ("timesfm", "TimesFMEncoderModel"), ("umt5", "UMT5EncoderModel"), ("xlm", "XLMModel"), ("xlm-roberta", "XLMRobertaModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 25055a5baa20..b094f50b5e97 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -472,12 +472,6 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "timesfm", - ( - "T5Tokenizer" if is_sentencepiece_available() else None, - "T5TokenizerFast" if is_tokenizers_available() else None, - ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 1abef3d3e175..baa30b11af21 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -30,14 +30,9 @@ pass else: _import_structure["modeling_timesfm"] = [ - "TimesFMEncoderModel", - "TimesFMForConditionalGeneration", + "TimesFMForPrediction", "TimesFMModel", "TimesFMPreTrainedModel", - "load_tf_weights_in_timesfm", - "TimesFMForQuestionAnswering", - "TimesFMForSequenceClassification", - "TimesFMForTokenClassification", ] if TYPE_CHECKING: @@ -50,14 +45,9 @@ pass else: from .modeling_timesfm import ( - TimesFMEncoderModel, - TimesFMForConditionalGeneration, - TimesFMForQuestionAnswering, - TimesFMForSequenceClassification, - TimesFMForTokenClassification, + TimesFMForPrediction, TimesFMModel, TimesFMPreTrainedModel, - load_tf_weights_in_timesfm, ) else: diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8542c54277bc..2d35ecdeeca3 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -31,9 +31,6 @@ BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, - Seq2SeqQuestionAnsweringModelOutput, - Seq2SeqSequenceClassifierOutput, - TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ( @@ -978,9 +975,7 @@ def _init_weights(self, module): module, ( TimesFMModel, - TimesFMForConditionalGeneration, - TimesFMEncoderModel, - TimesFMForQuestionAnswering, + TimesFMForPrediction, ), ): # Mesh TensorFlow embeddings initialization @@ -993,10 +988,6 @@ def _init_weights(self, module): mean=0.0, std=factor * ((self.config.d_model) ** -0.5) ) module.qa_outputs.bias.data.zero_() - elif isinstance(module, TimesFMForTokenClassification): - if hasattr(module, "classifier"): - module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.data.zero_() elif isinstance(module, TimesFMClassificationHead): module.dense.weight.data.normal_( mean=0.0, std=factor * ((self.config.d_model) ** -0.5) @@ -1726,7 +1717,7 @@ def forward( @add_start_docstrings( """TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING ) -class TimesFMForConditionalGeneration(TimesFMPreTrainedModel): +class TimesFMForPrediction(TimesFMPreTrainedModel): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] @@ -2064,502 +2055,3 @@ def _reorder_cache(self, past_key_values, beam_idx): reordered_layer_past_states, ) return reordered_decoder_past - - -@add_start_docstrings( - "The bare TIMESFM Model transformer outputting encoder's raw hidden-states without any specific head on top.", - TIMESFM_START_DOCSTRING, -) -class TimesFMEncoderModel(TimesFMPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight"] - _keys_to_ignore_on_load_unexpected = [r"decoder"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.freq_emb = nn.Embedding(config.freq_size, config.d_model) - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.d_model, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - def get_encoder(self): - return self.encoder - - @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, TimesFMEncoderModel - - >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") - >>> model = TimesFMEncoderModel.from_pretrained("google/timesfm-1.0-200m") - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> outputs = model(input_ids=input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - return encoder_outputs - - -@add_start_docstrings( - """ - TIMESFM model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. - """, - TIMESFM_START_DOCSTRING, -) -class TimesFMForSequenceClassification(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" - ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.transformer = TimesFMModel(config) - self.classification_head = TimesFMClassificationHead(config) - - # Initialize weights and apply final processing - self.post_init() - - self.model_parallel = False - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - Returns: - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - if labels is not None: - use_cache = False - - if input_ids is None and inputs_embeds is not None: - raise NotImplementedError( - f"Passing input embeddings is currently not supported for {self.__class__.__name__}" - ) - - # decoder_input_ids from input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - decoder_input_ids = self._shift_right(input_ids) - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - - eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) - - if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - batch_size, _, hidden_size = sequence_output.shape - sentence_representation = sequence_output[eos_mask, :].view( - batch_size, -1, hidden_size - )[:, -1, :] - logits = self.classification_head(sentence_representation) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.config.num_labels == 1: - self.config.problem_type = "regression" - elif self.config.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.config.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.config.num_labels), labels.view(-1) - ) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Seq2SeqSequenceClassifierOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - TIMESFM Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output) - e.g. for Named-Entity-Recognition (NER) tasks. - """, - TIMESFM_START_DOCSTRING, -) -class TimesFMForTokenClassification(TimesFMPreTrainedModel): - _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.num_labels = config.num_labels - - self.transformer = TimesFMEncoderModel(config) - self.dropout = nn.Dropout(config.classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - Returns: - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits, outputs[2:-1]) - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - TIMESFM Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers - on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - TIMESFM_START_DOCSTRING, -) -class TimesFMForQuestionAnswering(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" - ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TimesFMStack(decoder_config, self.shared) - - self.num_labels = config.num_labels - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - self.model_parallel = False - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - Returns: - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if start_positions is not None and end_positions is not None: - use_cache = False - - # different to other models, TIMESFM automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - decoder_input_ids = self._shift_right(input_ids) - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=None, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1).to(start_logits.device) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1).to(end_logits.device) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs - return ((total_loss,) + output) if total_loss is not None else output - - return Seq2SeqQuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index e5878f8c51c7..e08277fac50f 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -21,7 +21,6 @@ import unittest from transformers import TimesFMConfig, is_torch_available -from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES from transformers.testing_utils import ( require_accelerate, require_sentencepiece, @@ -48,11 +47,7 @@ from transformers import ( AutoTokenizer, ByT5Tokenizer, - TimesFMEncoderModel, - TimesFMForConditionalGeneration, - TimesFMForQuestionAnswering, - TimesFMForSequenceClassification, - TimesFMForTokenClassification, + TimesFMForPrediction, TimesFMModel, T5Tokenizer, ) @@ -249,7 +244,7 @@ def create_and_check_with_lm_head( decoder_attention_mask, lm_labels, ): - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() outputs = model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, @@ -260,26 +255,6 @@ def create_and_check_with_lm_head( self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) self.parent.assertEqual(outputs["loss"].size(), ()) - def create_and_check_with_sequence_classification_head( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device) - model = TimesFMForSequenceClassification(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - decoder_input_ids=input_ids, - labels=labels, - ) - # self.parent.assertEqual(len(outputs), 4) - self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels)) - self.parent.assertEqual(outputs["loss"].size(), ()) - def create_and_check_decoder_model_past( self, config, @@ -415,7 +390,7 @@ def create_and_check_generate_with_past_key_values( decoder_attention_mask, lm_labels, ): - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() torch.manual_seed(0) output_without_past_cache = model.generate( input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False @@ -446,7 +421,7 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask, lm_labels, ): - for model_class in [TimesFMModel, TimesFMForConditionalGeneration]: + for model_class in [TimesFMModel, TimesFMForPrediction]: torch.manual_seed(0) model = model_class(config=config).to(torch_device).eval() # load state dict copies weights but does not tie them @@ -520,7 +495,7 @@ def check_resize_embeddings_timesfm_v1_1( prev_vocab_size = config.vocab_size config.tie_word_embeddings = False - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() model.resize_token_embeddings(prev_vocab_size - 10) self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) @@ -551,12 +526,12 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForSequenceClassification, TimesFMForQuestionAnswering) + (TimesFMModel, TimesFMForPrediction) if is_torch_available() else () ) - all_generative_model_classes = (TimesFMForConditionalGeneration,) if is_torch_available() else () - all_parallelizable_model_classes = (TimesFMModel, TimesFMForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (TimesFMForPrediction,) if is_torch_available() else () + all_parallelizable_model_classes = (TimesFMModel, TimesFMForPrediction) if is_torch_available() else () fx_compatible = False test_pruning = False test_resize_embeddings = True @@ -573,7 +548,7 @@ def setUp(self): def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForQuestionAnswering): + for model_class in (TimesFMModel, TimesFMForPrediction): model = model_class(config) model.to(torch_device) model.eval() @@ -609,10 +584,6 @@ def test_with_lm_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_lm_head(*config_and_inputs) - def test_with_sequence_classification_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) - def test_decoder_model_past(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) @@ -695,7 +666,7 @@ def test_generate_with_head_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() config = config_and_inputs[0] max_length = config_and_inputs[1].shape[-1] + 3 - model = TimesFMForConditionalGeneration(config).eval() + model = TimesFMForPrediction(config).eval() model.to(torch_device) head_masking = { @@ -799,50 +770,6 @@ def prepare_config_and_inputs(self): attention_mask, ) - def create_and_check_model( - self, - config, - input_ids, - attention_mask, - ): - model = TimesFMEncoderModel(config=config) - model.to(torch_device) - model.eval() - result = model( - input_ids=input_ids, - attention_mask=attention_mask, - ) - result = model(input_ids=input_ids) - encoder_output = result.last_hidden_state - - self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) - - def create_and_check_model_fp16_forward( - self, - config, - input_ids, - attention_mask, - ): - model = TimesFMEncoderModel(config=config).to(torch_device).half().eval() - output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def create_and_check_with_token_classification_head( - self, - config, - input_ids, - attention_mask, - ): - labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) - model = TimesFMForTokenClassification(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - labels=labels, - attention_mask=attention_mask, - ) - self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) - self.parent.assertEqual(outputs["loss"].size(), ()) - def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -858,41 +785,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -class TimesFMEncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (TimesFMEncoderModel, TimesFMForTokenClassification) if is_torch_available() else () - test_pruning = False - test_resize_embeddings = False - test_model_parallel = True - pipeline_model_mapping = ( - { - "token-classification": TimesFMForTokenClassification, - } - if is_torch_available() - else {} - ) - all_parallelizable_model_classes = (TimesFMEncoderModel,) if is_torch_available() else () - - def setUp(self): - self.model_tester = TimesFMEncoderOnlyModelTester(self) - self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_model_fp16_forward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) - - def test_with_token_classification_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) - - def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task]) @@ -922,38 +814,38 @@ def import_accelerate_mock(name, *args, **kwargs): with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): accelerate_available = False - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) # Load without in bf16 - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load using `accelerate` in bf16 - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, device_map="auto" ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load using `accelerate` in bf16 - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load without using `accelerate` - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.float16, low_cpu_mem_usage=True ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) # Load using `accelerate` - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.float16, device_map="auto" ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) @@ -966,7 +858,7 @@ def import_accelerate_mock(name, *args, **kwargs): class TimesFMModelIntegrationTests(unittest.TestCase): @cached_property def model(self): - return TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-base").to(torch_device) + return TimesFMForPrediction.from_pretrained("google-timesfm/timesfm-base").to(torch_device) @cached_property def tokenizer(self): @@ -979,7 +871,7 @@ def test_torch_quant(self): """ model_name = "google/flan-timesfm-small" tokenizer = T5Tokenizer.from_pretrained(model_name) - model = TimesFMForConditionalGeneration.from_pretrained(model_name) + model = TimesFMForPrediction.from_pretrained(model_name) model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" input_ids = tokenizer(input_text, return_tensors="pt").input_ids @@ -987,7 +879,7 @@ def test_torch_quant(self): @slow def test_small_generation(self): - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) model.config.max_length = 8 model.config.num_beams = 1 model.config.do_sample = False @@ -1014,7 +906,7 @@ def test_small_integration_test(self): >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -1040,7 +932,7 @@ def test_small_v1_1_integration_test(self): >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-v1_1-small").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/timesfm-v1_1-small").to(torch_device) tokenizer = T5Tokenizer.from_pretrained("google/timesfm-v1_1-small") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -1064,7 +956,7 @@ def test_small_bytimesfm_integration_test(self): >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TimesFMForConditionalGeneration.from_pretrained("google/bytimesfm-small").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/bytimesfm-small").to(torch_device) tokenizer = ByT5Tokenizer.from_pretrained("google/bytimesfm-small") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -1405,7 +1297,7 @@ def test_contrastive_search_timesfm(self): ) article = "summarize: " + article.strip() timesfm_tokenizer = AutoTokenizer.from_pretrained("flax-community/timesfm-base-cnn-dm") - timesfm_model = TimesFMForConditionalGeneration.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) + timesfm_model = TimesFMForPrediction.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) input_ids = timesfm_tokenizer( article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" ).input_ids.to(torch_device) @@ -1434,7 +1326,7 @@ def build_model_and_check_forward_pass(self, **kwargs): decoder_attention_mask, lm_labels, ) = inputs - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() outputs = model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, From 5219e5943228ac7ac7cc6333f3e3be9e59802c0a Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Tue, 10 Sep 2024 19:30:18 -0700 Subject: [PATCH 009/242] TimesFM Model --- .../models/timesfm/configuration_timesfm.py | 8 +- .../models/timesfm/modeling_timesfm.py | 1445 ++--------------- 2 files changed, 107 insertions(+), 1346 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index cdee64d7f377..16da290cc0cb 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -47,18 +47,18 @@ class TimesFMConfig(PretrainedConfig): The tolerance for the quantile loss. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. - d_model (`int`, *optional*, defaults to 512): + d_model (`int`, *optional*, defaults to 1280): Size of the encoder layers and the pooler layer. - d_kv (`int`, *optional*, defaults to 64): + d_kv (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * d_kv`. d_ff (`int`, *optional*, defaults to 1280): Size of the intermediate feed forward layer in each `TimesFMBlock`. - num_layers (`int`, *optional*, defaults to 6): + num_layers (`int`, *optional*, defaults to 20): Number of hidden layers in the Transformer encoder. num_decoder_layers (`int`, *optional*): Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. - num_heads (`int`, *optional*, defaults to 8): + num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. relative_attention_num_buckets (`int`, *optional*, defaults to 32): The number of buckets to use for each attention layer. diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 2d35ecdeeca3..852d91320889 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -56,134 +56,6 @@ _CONFIG_FOR_DOC = "TimesFMConfig" _CHECKPOINT_FOR_DOC = "google/timesfm-1.0-200m" -#################################################### -# This dict contains ids and associated url -# for the pretrained weights provided with the models -#################################################### - - -#################################################### -# This is a conversion method from TF 1.0 to PyTorch -# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 -#################################################### -# Copied from transformers.models.t5.modeling_t5.load_tf_weights_in_t5 with t5->timesfm -def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - tf_weights[name] = array - - for txt_name in names: - name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n - in [ - "adam_v", - "adam_m", - "AdamWeightDecayOptimizer", - "AdamWeightDecayOptimizer_1", - "global_step", - ] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - if "_slot_" in name[-1]: - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - pointer = model - array = tf_weights[txt_name] - - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - elif scope_names[0] == "self_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[0] - elif scope_names[0] == "enc_dec_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[1] - elif scope_names[0] == "dense_relu_dense": - pointer = getattr(pointer, "layer") - pointer = pointer[2] - elif scope_names[0] == "rms_norm": - if hasattr(pointer, "layer_norm"): - pointer = getattr(pointer, "layer_norm") - elif hasattr(pointer, "final_layer_norm"): - pointer = getattr(pointer, "final_layer_norm") - elif scope_names[0] == "scale": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - elif scope_names[0] == "decoder" and name[1] == "logits": - continue - elif scope_names[0] == "logits": - pointer = getattr(pointer, "lm_head") - elif ( - scope_names[0] == "wi" - and len(scope_names) > 1 - and scope_names[1].isdigit() - ): - pointer = getattr(pointer, f"wi_{scope_names[1]}") - continue - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if scope_names[0] not in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - if scope_names[0] != "embedding": - logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array.astype(np.float32)) - tf_weights.pop(txt_name, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") - return model - #################################################### # PyTorch Models are constructed by sub-classing @@ -696,642 +568,139 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->TimesFM -class TimesFMLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): - super().__init__() - self.SelfAttention = TimesFMAttention( - config, has_relative_attention_bias=has_relative_attention_bias - ) - self.layer_norm = TimesFMLayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->TimesFM -class TimesFMLayerCrossAttention(nn.Module): - def __init__(self, config): +class TimesFMTransformerLayer(nn.Module): + def __init__(self, config: TimesFMConfig): super().__init__() - self.EncDecAttention = TimesFMAttention( - config, has_relative_attention_bias=False - ) - self.layer_norm = TimesFMLayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.attention = TimesFMAttention(config) + self.ff = TimesFMLayerFF(config) + self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - query_length=query_length, - output_attentions=output_attentions, - ) - layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs + def forward(self, inputs, mask=None): + x = self.layer_norm(inputs) + x = self.attention(x, mask=mask) + x = self.dropout(x) + x = x + inputs + x = self.ff(x) + return x -# Copied from transformers.models.t5.modeling_t5.T5Block with T5->TimesFM -class TimesFMBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): +class TimesFMTransformerStack(nn.Module): + def __init__(self, config: TimesFMConfig): super().__init__() - self.is_decoder = config.is_decoder - self.layer = nn.ModuleList() - self.layer.append( - TimesFMLayerSelfAttention( - config, has_relative_attention_bias=has_relative_attention_bias - ) + self.layers = nn.ModuleList( + [TimesFMTransformerLayer(config) for _ in range(config.num_layers)] ) - if self.is_decoder: - self.layer.append(TimesFMLayerCrossAttention(config)) - self.layer.append(TimesFMLayerFF(config)) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning( - "`past_key_values` is passed to the encoder. Please make sure this is intended." - ) - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + def forward(self, hidden_states, mask=None): + for layer in self.layers: + hidden_states = layer(hidden_states, mask=mask) + return hidden_states - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None +class TimesFMModel(PreTrainedModel): + def __init__(self, config: TimesFMConfig): + super().__init__(config) - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, + self.freq_emb = nn.Embedding( + num_embeddings=config.freq_size, + embedding_dim=config.d_model, + ) + self.position_emb = TimesFMPositionalEmbedding( + embedding_dims=config.d_model, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[ - 2: - ] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - outputs = (hidden_states,) - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs + self.input_ff_layer = TimesFMResidualBlock( + input_dims=config.patch_len * 2, + output_dims=config.d_model, + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + self.stacked_transformer_layer = TimesFMTransformerStack(config) + def preprocess_inputs(self, inputs): + assert len(inputs.shape) == 3 # (batch_size, num_patches, patch_len) + inputs_mean = inputs.mean(dim=(1, 2)) + inputs_std = inputs.std(dim=(1, 2)) + processed_input = (inputs - inputs_mean[:, None, None]) / inputs_std[ + :, None, None + ] + return processed_input, (inputs_mean, inputs_std) -# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->TimesFM -class TimesFMClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" + def create_causal_mask(batch_size, seq_len): + mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() + mask = mask.unsqueeze(0).unsqueeze(1) + mask = mask.expand(batch_size, 1, seq_len, seq_len) + mask = mask.float().masked_fill(mask, -2.3819763e38).masked_fill(~mask, 0.0) + return mask + def forward( + self, + input_ts, + ): + batch_size = input_ts.shape[0] + patched_inputs = input_ts.reshape(batch_size, -1, self.config.patch_len) + patched_pads = torch.zeros_like(patched_inputs) + patched_inputs, input_stats = self.preprocess_inputs(patched_inputs) + concat_inputs = torch.concat([patched_inputs, patched_pads], dim=-1) + + model_input = self.input_ff_layer(concat_inputs) + position_emb = self.position_emb(seq_length=model_input.shape[1]).expand( + model_input.shape[0], -1, -1 + ) + model_input = model_input + position_emb + f_emb = self.freq_emb( + torch.zeros((batch_size, 1), dtype=torch.long) + ) # freq set to zero, change if needed + model_input = model_input + f_emb + mask = self.create_causal_mask(model_input.shape[0], model_input.shape[1]) + model_output = self.stacked_transformer_layer(model_input, mask=mask) + return model_output, input_stats + + +class TimesFMPredictionHead(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - self.dense = nn.Linear(config.d_model, config.d_model) - self.dropout = nn.Dropout(p=config.classifier_dropout) - self.out_proj = nn.Linear(config.d_model, config.num_labels) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->TimesFM,t5->timesfm -class TimesFMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = TimesFMConfig - load_tf_weights = load_tf_weights_in_timesfm - base_model_prefix = "transformer" - is_parallelizable = True - supports_gradient_checkpointing = True - _no_split_modules = ["TimesFMBlock"] - _keep_in_fp32_modules = ["wo"] - - @property - def dummy_inputs(self): - input_ids = torch.tensor(DUMMY_INPUTS) - input_mask = torch.tensor(DUMMY_MASK) - dummy_inputs = { - "decoder_input_ids": input_ids, - "input_ids": input_ids, - "decoder_attention_mask": input_mask, - } - return dummy_inputs - - def _init_weights(self, module): - """Initialize the weights""" - factor = ( - self.config.initializer_factor - ) # Used for testing weights initialization - if isinstance(module, TimesFMLayerNorm): - module.weight.data.fill_(factor * 1.0) - elif isinstance( - module, - ( - TimesFMModel, - TimesFMForPrediction, - ), - ): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) - if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) - if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - module.qa_outputs.bias.data.zero_() - elif isinstance(module, TimesFMClassificationHead): - module.dense.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() - elif isinstance(module, TimesFMDenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) - ) - if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() - elif isinstance(module, TimesFMPerHeadDimScale): - module.scale.data.zero_() - elif isinstance(module, TimesFMAttention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 - d_model = self.config.d_model - key_value_proj_dim = self.config.d_kv - n_heads = self.config.num_heads - module.q.weight.data.normal_( - mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) - ) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_( - mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) - ) - if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_( - mean=0.0, std=factor * ((d_model) ** -0.5) - ) - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id + self.config = config + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.d_model, + output_dims=config.horizon_len, + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) - if decoder_start_token_id is None: - raise ValueError( - "self.model.config.decoder_start_token_id has to be defined. In TimesFM it is usually set to the pad_token_id. " - "See TimesFM docs for more information." - ) + def postprocess_outputs(self, outputs, stats): + mean, std = stats + return outputs * std[:, None, None, None] + mean[:, None, None, None] - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full( - input_ids.shape[:-1] + (1,), decoder_start_token_id - ) - shifted_input_ids = torch.cat( - [shifted_input_ids, input_ids[..., :-1]], dim=-1 - ) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id + def forward(self, model_output, input_stats): + batch_size = model_output.shape[0] + output_ts = self.horizon_ff_layer(model_output) - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + assert self.config.d_model % self.config.horizon_len == 0 + num_outputs = self.config.d_model // self.config.horizon_len - return shifted_input_ids + output_ts = output_ts.reshape( + batch_size, -1, self.config.horizon_len, num_outputs + ) + output_ts = self.postprocess_outputs(output_ts, input_stats) + return output_ts -class TimesFMStack(TimesFMPreTrainedModel): - def __init__(self, config): +class TimesFMForPrediction(PreTrainedModel): + def __init__(self, config: TimesFMConfig): super().__init__(config) - - self.is_decoder = config.is_decoder - - self.block = nn.ModuleList( - [ - TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) - for i in range(config.num_layers) - ] - ) - self.dropout = nn.Dropout(config.dropout_rate) - - # Initialize weights and apply final processing - self.post_init() - # Model parallel - self.device_map = None - self.gradient_checkpointing = False + self.timesfm = TimesFMModel(config) + self.prediction_head = TimesFMPredictionHead(config) def forward( self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ts, ): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(self.first_device) - self.embed_tokens = self.embed_tokens.to(self.first_device) - use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" - ) - - if inputs_embeds is None: - if self.embed_tokens is None: - raise ValueError( - "You have to initialize the model with valid token embeddings" - ) - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values[0][0].shape[2] + seq_length - if past_key_values is not None - else seq_length - ) - - if use_cache is True: - if not self.is_decoder: - raise ValueError( - f"`use_cache` can only be set to `True` if {self} is used as a decoder" - ) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape - ) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = ( - encoder_hidden_states.size() - ) - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long - ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask( - cross_attn_head_mask, self.config.num_layers - ) - present_key_value_states = () if use_cache else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds) - - for i, (layer_module, past_key_value) in enumerate( - zip(self.block, past_key_values) - ): - layer_head_mask = head_mask[i] - cross_attn_layer_head_mask = cross_attn_head_mask[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to( - hidden_states.device - ) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = ( - encoder_extended_attention_mask.to(hidden_states.device) - ) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to( - hidden_states.device - ) - if layer_head_mask is not None: - layer_head_mask = layer_head_mask.to(hidden_states.device) - if cross_attn_layer_head_mask is not None: - cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( - hidden_states.device - ) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - extended_attention_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[ - 4 if output_attentions else 3 - ] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + ( - present_key_value_state, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.dropout(hidden_states) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) + model_output, input_stats = self.timesfm(input_ts) + output_ts = self.prediction_head(model_output, input_stats) + return output_ts TIMESFM_START_DOCSTRING = r""" @@ -1447,611 +816,3 @@ def forward( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - -TIMESFM_ENCODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. -""" - - -@add_start_docstrings( - "The bare TIMESFM Model transformer outputting raw hidden-states without any specific head on top.", - TIMESFM_START_DOCSTRING, -) -class TimesFMModel(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", - ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.freq_emb = nn.Embedding(config.freq_size, config.d_model) - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.d_model, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TimesFMStack(decoder_config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" - " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" - " 0, 'encoder.block.1': 1, ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - decoder_inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, TimesFMModel - - >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") - >>> model = TimesFMModel.from_pretrained("google/timesfm-1.0-200m") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for TimesFMModel. - >>> # This is not needed for torch's TimesFMForConditionalGeneration as it does this internally using labels arg. - >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to( - self.decoder.first_device - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings( - """TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING -) -class TimesFMForPrediction(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", - ] - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - "lm_head.weight", - ] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TimesFMStack(decoder_config, self.shared) - - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" - " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" - " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" - " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.decoder.first_device) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def get_output_embeddings(self): - return self.lm_head - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., - config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for - labels in `[0, ..., config.vocab_size]` - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TimesFMForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") - >>> model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m") - - >>> # training - >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids - >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids - >>> outputs = model(input_ids=input_ids, labels=labels) - >>> loss = outputs.loss - >>> logits = outputs.logits - - >>> # inference - >>> input_ids = tokenizer( - ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> outputs = model.generate(input_ids) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - >>> # studies have shown that owning a dog is good for you. - ```""" - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - - if ( - labels is not None - and decoder_input_ids is None - and decoder_inputs_embeds is None - ): - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to( - self.decoder.first_device - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.encoder.first_device) - self.lm_head = self.lm_head.to(self.encoder.first_device) - sequence_output = sequence_output.to(self.lm_head.weight.device) - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - lm_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-100) - # move labels to correct device to enable PP - labels = labels.to(lm_logits.device) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning( - "You might want to consider setting `use_cache=True` to speed up decoding" - ) - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device) - ), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, - ) - return reordered_decoder_past From 71cd41a1b74355dc243cb5fa889df4aff024e666 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 13:55:27 +0200 Subject: [PATCH 010/242] order of imports --- src/transformers/__init__.py | 12 ++++++------ src/transformers/models/__init__.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0a86ab15a1e7..428c19f845fb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5556,7 +5556,6 @@ SwitchTransformersConfig, ) from .models.t5 import T5Config - from .models.timesfm import TimesFMConfig from .models.table_transformer import ( TableTransformerConfig, ) @@ -5567,6 +5566,7 @@ from .models.time_series_transformer import ( TimeSeriesTransformerConfig, ) + from .models.timesfm import TimesFMConfig from .models.timesformer import ( TimesformerConfig, ) @@ -7767,11 +7767,6 @@ T5PreTrainedModel, load_tf_weights_in_t5, ) - from .models.timesfm import ( - TimesFMForPrediction, - TimesFMModel, - TimesFMPreTrainedModel, - ) from .models.table_transformer import ( TableTransformerForObjectDetection, TableTransformerModel, @@ -7790,6 +7785,11 @@ TimeSeriesTransformerModel, TimeSeriesTransformerPreTrainedModel, ) + from .models.timesfm import ( + TimesFMForPrediction, + TimesFMModel, + TimesFMPreTrainedModel, + ) from .models.timesformer import ( TimesformerForVideoClassification, TimesformerModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7236a0f52361..c2299c875c52 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -232,10 +232,10 @@ swinv2, switch_transformers, t5, - timesfm, table_transformer, tapas, time_series_transformer, + timesfm, timesformer, timm_backbone, trocr, From 5877388e4c1c5315e9dd2aaffb90394749eebeef Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 18 Sep 2024 17:27:31 -0700 Subject: [PATCH 011/242] copy from Google official implementation --- .../models/timesfm/configuration_timesfm.py | 124 +-- ...mesfm_original_tf_checkpoint_to_pytorch.py | 75 -- .../convert_timesfmx_checkpoint_to_flax.py | 299 ------ .../convert_timesfmx_checkpoint_to_pytorch.py | 279 ------ .../models/timesfm/modeling_timesfm.py | 940 +++--------------- .../models/timesfm/patched_decoder.py | 766 ++++++++++++++ .../models/timesfm/timesfm_base.py | 572 +++++++++++ src/transformers/models/timesfm/xreg_lib.py | 520 ++++++++++ 8 files changed, 2063 insertions(+), 1512 deletions(-) delete mode 100644 src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py delete mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/timesfm/patched_decoder.py create mode 100644 src/transformers/models/timesfm/timesfm_base.py create mode 100644 src/transformers/models/timesfm/xreg_lib.py diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 16da290cc0cb..de82a874771b 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020, The TimesFM Authors and HuggingFace Inc. +# Copyright 2024 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,54 +36,47 @@ class TimesFMConfig(PretrainedConfig): Arguments: patch_len (`int`, *optional*, defaults to 32): - The length of each patch in the sequence. + The length of one patch in the input sequence. horizon_len (`int`, *optional*, defaults to 128): The length of the prediction horizon. - quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): - The quantiles to predict. - pad_val (`float`, *optional*, defaults to 1123581321.0): - The value used to pad the predictions. - tolerance (`float`, *optional*, defaults to 1e-6): - The tolerance for the quantile loss. + context_len (`int`, *optional*, defaults to 512): + The length of the input context. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. - d_model (`int`, *optional*, defaults to 1280): - Size of the encoder layers and the pooler layer. - d_kv (`int`, *optional*, defaults to 80): + model_dim (`int`, *optional*, defaults to 1280): + Size of the hidden layers in the feed-forward networks. + head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will - be defined as `num_heads * d_kv`. - d_ff (`int`, *optional*, defaults to 1280): - Size of the intermediate feed forward layer in each `TimesFMBlock`. + be defined as `num_heads * head_dim`. num_layers (`int`, *optional*, defaults to 20): - Number of hidden layers in the Transformer encoder. - num_decoder_layers (`int`, *optional*): - Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + Number of Transformer layers. num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. - relative_attention_num_buckets (`int`, *optional*, defaults to 32): - The number of buckets to use for each attention layer. - relative_attention_max_distance (`int`, *optional*, defaults to 128): - The maximum distance of the longer sequences for the bucket separation. + tolerance (`float`, *optional*, defaults to 1e-6): + The tolerance for the quantile loss. dropout_rate (`float`, *optional*, defaults to 0.1): The ratio for all dropout layers. classifier_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for classifier. - layer_norm_eps (`float`, *optional*, defaults to 1e-6): - The epsilon used by the layer normalization layers. + rms_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the RMS normalization layers. + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): + The quantiles to predict. + pad_val (`float`, *optional*, defaults to 1123581321.0): + The value used to pad the predictions. + use_positional_embedding (`bool`, *optional*, defaults to `True`): + Whether to add positional embeddings. + per_core_batch_size (`int`, *optional*, defaults to 32): + The batch size per core for data parallelism. initializer_factor (`float`, *optional*, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - feed_forward_proj (`string`, *optional*, defaults to `"relu"`): - Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. TimesFMv1.1 uses the - `"gated-gelu"` feed forward projection. Original TimesFM uses `"relu"`. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). """ model_type = "timesfm" - keys_to_ignore_at_inference = ["past_key_values"] + keys_to_ignore_at_inference = [] attribute_map = { - "hidden_size": "d_model", + "hidden_size": "hidden_size", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } @@ -91,72 +84,41 @@ class TimesFMConfig(PretrainedConfig): def __init__( self, patch_len: int = 32, + context_len: int = 512, horizon_len: int = 128, - quantiles: List[float] = [0.1, 0.25, 0.5, 0.75, 0.9], - pad_val: float = 1123581321.0, + freq_size: int = 3, + num_layers: int = 20, + model_dim: int = 1280, + head_dim: int = 80, + num_heads: int = 16, + dropout_rate: float = 0.1, tolerance: float = 1e-6, - freq_size=3, - d_model=1280, - d_kv=80, - d_ff=1280, - num_layers=20, - num_decoder_layers=None, - num_heads=16, - relative_attention_num_buckets=32, - relative_attention_max_distance=128, - dropout_rate=0.1, - layer_norm_epsilon=1e-6, - initializer_factor=1.0, - feed_forward_proj="relu", - is_encoder_decoder=True, - use_cache=True, - pad_token_id=0, - eos_token_id=1, - classifier_dropout=0.0, + rms_norm_eps: float = 1e-6, + quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + pad_val: float = 1123581321.0, + use_positional_embedding: bool = True, + per_core_batch_size: int = 32, + initializer_factor: float = 1.0, **kwargs, ): self.patch_len = patch_len + self.context_len = context_len self.horizon_len = horizon_len self.quantiles = quantiles self.pad_val = pad_val - self.tolerance = tolerance self.freq_size = freq_size - self.d_model = d_model - self.d_kv = d_kv - self.d_ff = d_ff + self.model_dim = model_dim + self.head_dim = head_dim self.num_layers = num_layers - self.num_decoder_layers = ( - num_decoder_layers if num_decoder_layers is not None else self.num_layers - ) # default = symmetry self.num_heads = num_heads - self.relative_attention_num_buckets = relative_attention_num_buckets - self.relative_attention_max_distance = relative_attention_max_distance self.dropout_rate = dropout_rate - self.classifier_dropout = classifier_dropout - self.layer_norm_epsilon = layer_norm_epsilon + self.tolerance = tolerance + self.rms_norm_eps = rms_norm_eps + self.use_positional_embedding = use_positional_embedding + self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.feed_forward_proj = feed_forward_proj - self.use_cache = use_cache - - act_info = self.feed_forward_proj.split("-") - self.dense_act_fn = act_info[-1] - self.is_gated_act = act_info[0] == "gated" - - if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: - raise ValueError( - f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " - "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " - "'gated-gelu' or 'relu'" - ) - - # for backwards compatibility - if feed_forward_proj == "gated-gelu": - self.dense_act_fn = "gelu_new" super().__init__( - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - is_encoder_decoder=is_encoder_decoder, **kwargs, ) diff --git a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index b1ce727cac0c..000000000000 --- a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,75 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The TimesFM authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert TimesFM checkpoint.""" - -import argparse - -from transformers import ( - TimesFMConfig, - TimesFMForConditionalGeneration, - load_tf_weights_in_timesfm, -) -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch( - tf_checkpoint_path, config_file, pytorch_dump_path -): - # Initialise PyTorch model - config = TimesFMConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = TimesFMForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_tf_weights_in_timesfm(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the TensorFlow checkpoint path.", - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained TimesFM model. \nThis specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to the output PyTorch model.", - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path - ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py deleted file mode 100644 index f9468ffb84c6..000000000000 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py +++ /dev/null @@ -1,299 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert TimesFMX checkpoints from the original repository to JAX/FLAX model.""" - -import argparse - -from timesfmx import checkpoints - -from transformers import FlaxTimesFMForConditionalGeneration, TimesFMConfig - - -def convert_timesfmx_checkpoint_to_flax( - timesfmx_checkpoint_path, config_name, flax_dump_folder_path -): - config = TimesFMConfig.from_pretrained(config_name) - flax_model = FlaxTimesFMForConditionalGeneration(config=config) - timesfmx_model = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) - - split_mlp_wi = "wi_0" in timesfmx_model["target"]["encoder"]["layers_0"]["mlp"] - - # Encoder - for layer_index in range(config.num_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["value"]["kernel"] - - # Layer Normalization - timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ - "pre_attention_layer_norm" - ]["scale"] - - if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ - "wi_0" - ]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ - "wi_1" - ]["kernel"] - else: - timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ - "wi" - ]["kernel"] - - timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"][ - "kernel" - ] - - # Layer Normalization - timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ - "pre_mlp_layer_norm" - ]["scale"] - - # Assigning - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["k"]["kernel"] = timesfmx_attention_key - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["o"]["kernel"] = timesfmx_attention_out - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["q"]["kernel"] = timesfmx_attention_query - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["v"]["kernel"] = timesfmx_attention_value - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "layer_norm" - ]["weight"] = timesfmx_attention_layer_norm - - if split_mlp_wi: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 - else: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wi"]["kernel"] = timesfmx_mlp_wi - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wo"]["kernel"] = timesfmx_mlp_wo - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "layer_norm" - ]["weight"] = timesfmx_mlp_layer_norm - - # Only for layer 0: - timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"][ - "rel_embedding" - ].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ - "relative_attention_bias" - ]["embedding"] = timesfmx_encoder_rel_embedding - - # Assigning - timesfmx_encoder_norm = timesfmx_model["target"]["encoder"]["encoder_norm"]["scale"] - flax_model.params["encoder"]["final_layer_norm"]["weight"] = timesfmx_encoder_norm - - # Decoder - for layer_index in range(config.num_decoder_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["value"]["kernel"] - - # Layer Normalization - timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][ - layer_name - ]["pre_self_attention_layer_norm"]["scale"] - - # Encoder-Decoder-Attention - timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["key"]["kernel"] - timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["out"]["kernel"] - timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["query"]["kernel"] - timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["value"]["kernel"] - - # Layer Normalization - timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ - "pre_cross_attention_layer_norm" - ]["scale"] - - # MLP - if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ - "wi_0" - ]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ - "wi_1" - ]["kernel"] - else: - timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ - "wi" - ]["kernel"] - - timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"][ - "kernel" - ] - - # Layer Normalization - tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ - "pre_mlp_layer_norm" - ]["scale"] - - # Assigning - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["k"]["kernel"] = timesfmx_attention_key - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["o"]["kernel"] = timesfmx_attention_out - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["q"]["kernel"] = timesfmx_attention_query - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["v"]["kernel"] = timesfmx_attention_value - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "layer_norm" - ]["weight"] = timesfmx_pre_attention_layer_norm - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["k"]["kernel"] = timesfmx_enc_dec_attention_key - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["o"]["kernel"] = timesfmx_enc_dec_attention_out - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["q"]["kernel"] = timesfmx_enc_dec_attention_query - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["v"]["kernel"] = timesfmx_enc_dec_attention_value - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "layer_norm" - ]["weight"] = timesfmx_cross_layer_norm - - if split_mlp_wi: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 - else: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wi"]["kernel"] = timesfmx_mlp_wi - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wo"]["kernel"] = timesfmx_mlp_wo - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "layer_norm" - ]["weight"] = tx5_mlp_layer_norm - - # Decoder Normalization - tx5_decoder_norm = timesfmx_model["target"]["decoder"]["decoder_norm"]["scale"] - flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm - - # Only for layer 0: - timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"][ - "rel_embedding" - ].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ - "relative_attention_bias" - ]["embedding"] = timesfmx_decoder_rel_embedding - - # Token Embeddings - tx5_token_embeddings = timesfmx_model["target"]["token_embedder"]["embedding"] - flax_model.params["shared"]["embedding"] = tx5_token_embeddings - - # LM Head (only in v1.1 checkpoints) - if "logits_dense" in timesfmx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"][ - "logits_dense" - ]["kernel"] - - flax_model.save_pretrained(flax_dump_folder_path) - print("TimesFMX Model was sucessfully converted!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--timesfmx_checkpoint_path", - default=None, - type=str, - required=True, - help="Path the TX5 checkpoint.", - ) - parser.add_argument( - "--config_name", - default=None, - type=str, - required=True, - help="Config name of TimesFM model.", - ) - parser.add_argument( - "--flax_dump_folder_path", - default=None, - type=str, - required=True, - help="Path to the output FLAX model.", - ) - args = parser.parse_args() - convert_timesfmx_checkpoint_to_flax( - args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path - ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py deleted file mode 100644 index 8d5f13535e8d..000000000000 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py +++ /dev/null @@ -1,279 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Convert TimesFMX checkpoint to PyTorch - -Steps: -- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install -- Get a TimesFMX checkpoint at https://github.com/google-research/timesfmx/blob/main/docs/models.md#timesfm-11-checkpoints Example: - `gsutil -m cp -r gs://timesfm-data/pretrained_models/timesfmx/timesfm_1_1_small $HOME/` -- Create or download a corresponding config for the downloaded model. E.g. for TimesFM v1.1 small, you can use - https://huggingface.co/google/timesfm-v1_1-small/blob/main/config.json -- Convert: - ``` - python3 convert_timesfmx_checkpoint_to_pytorch.py --timesfmx_checkpoint_path=$HOME/timesfm_1_1_small --config_file=config.json\ - --pytorch_dump_path=$HOME/timesfm_1_1_small_pt - ``` -""" - -import argparse -import collections - -import torch -from flax import traverse_util -from timesfmx import checkpoints - -from transformers import ( - TimesFMConfig, - TimesFMEncoderModel, - TimesFMForConditionalGeneration, -) -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def timesfmx_attention_lookup(params, i, prefix, layer_name="attention"): - """Returns the KOQV parameters of (self-)attention. Does not transpose.""" - k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] - o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] - q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] - v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] - return k, o, q, v - - -def timesfmx_mlp_lookup(params, i, prefix, split_mlp_wi=False): - """Returns the MLP parameters of a layer. Does not transpose.""" - if split_mlp_wi: - wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] - wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] - wi = (wi_0, wi_1) - else: - wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] - - wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] - return wi, wo - - -def timesfmx_layer_norm_lookup(params, i, prefix, layer_name): - """Returns the layer norm param of a layer.""" - return params[f"{prefix}/layers_{i}/{layer_name}/scale"] - - -def convert_timesfmx_to_pytorch( - variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool -): - """Converts the parameters from TimesFMX-Flax to Transformers-PyTorch.""" - old = traverse_util.flatten_dict(variables["target"]) - old = {"/".join(k): v for k, v in old.items()} - - # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi - split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old - print("Split MLP:", split_mlp_wi) - - new = collections.OrderedDict() - - # Shared embeddings. - new["shared.weight"] = old["token_embedder/embedding"] - - # Encoder. - for i in range(num_layers): - # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "encoder", "pre_attention_layer_norm" - ) - k, o, q, v = timesfmx_attention_lookup(old, i, "encoder", "attention") - new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm - new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T - new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T - new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T - new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T - - # Block i, layer 1 (MLP). - layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm") - wi, wo = timesfmx_mlp_lookup(old, i, "encoder", split_mlp_wi) - new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm - if split_mlp_wi: - new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T - new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T - else: - new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T - new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T - - new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ - "encoder/relpos_bias/rel_embedding" - ].T - new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] - - if not is_encoder_only: - # Decoder. - for i in range(num_decoder_layers): - # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "decoder", "pre_self_attention_layer_norm" - ) - k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "self_attention") - new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm - new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T - new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T - new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T - new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T - - # Block i, layer 1 (Cross Attention). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "decoder", "pre_cross_attention_layer_norm" - ) - k, o, q, v = timesfmx_attention_lookup( - old, i, "decoder", "encoder_decoder_attention" - ) - new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm - new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T - new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T - new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T - new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T - - # Block i, layer 2 (MLP). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "decoder", "pre_mlp_layer_norm" - ) - wi, wo = timesfmx_mlp_lookup(old, i, "decoder", split_mlp_wi) - new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm - if split_mlp_wi: - new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T - new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T - else: - new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T - new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T - - new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] - new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = ( - old["decoder/relpos_bias/rel_embedding"].T - ) - - # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) - if "decoder/logits_dense/kernel" in old: - new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T - - return new - - -def make_state_dict(converted_params, is_encoder_only: bool): - """Prepares a state dict for the PyTorch model.""" - # Make a state dict with torch tensors. - state_dict = collections.OrderedDict( - [(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()] - ) - - # Add what is missing. - if "encoder.embed_tokens.weight" not in state_dict: - state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] - - if not is_encoder_only: - if "decoder.embed_tokens.weight" not in state_dict: - state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] - - if "lm_head.weight" not in state_dict: # For old 1.0 models. - print("Using shared word embeddings as lm_head.") - state_dict["lm_head.weight"] = state_dict["shared.weight"] - - return state_dict - - -def load_timesfmx_weights_in_timesfm( - model, config, timesfmx_checkpoint_path, is_encoder_only -): - """Replaces the params in model witht the TimesFMX converted params.""" - variables = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) - converted = convert_timesfmx_to_pytorch( - variables, - num_layers=config.num_layers, - num_decoder_layers=config.num_decoder_layers, - is_encoder_only=is_encoder_only, - ) - state_dict = make_state_dict(converted, is_encoder_only) - model.load_state_dict(state_dict, strict=True) - - -def convert_timesfmx_checkpoint_to_pytorch( - timesfmx_checkpoint_path, - config_file, - pytorch_dump_path, - is_encoder_only: bool = False, -): - """Loads the config and model, converts the TimesFMX checkpoint, and saves a PyTorch checkpoint.""" - # Initialise PyTorch model - config = TimesFMConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - # Non-v1.1 checkpoints could also use TimesFMModel, but this works for all. - # The v1.0 checkpoints will simply have an LM head that is the word embeddings. - if is_encoder_only: - model = TimesFMEncoderModel(config) - else: - model = TimesFMForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_timesfmx_weights_in_timesfm( - model, config, timesfmx_checkpoint_path, is_encoder_only - ) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - # Verify that we can load the checkpoint. - model.from_pretrained(pytorch_dump_path) - print("Done") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint." - ) - # Required parameters - parser.add_argument( - "--timesfmx_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the TimesFMX checkpoint.", - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained TimesFM model.\nThis specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to the output PyTorch model.", - ) - parser.add_argument( - "--is_encoder_only", - action="store_true", - help="Check if the model is encoder-decoder model", - default=False, - ) - args = parser.parse_args() - convert_timesfmx_checkpoint_to_pytorch( - args.timesfmx_checkpoint_path, - args.config_file, - args.pytorch_dump_path, - args.is_encoder_only, - ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 852d91320889..ea27c1e75b8c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Mesh TensorFlow authors, TimesFM Authors and HuggingFace Inc. team. +# Copyright 2024 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,805 +14,189 @@ # limitations under the License. """PyTorch TimesFM model.""" -import copy -import math -import os -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from ...activations import ACT2FN -from ...modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, - Seq2SeqModelOutput, -) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ( - ALL_LAYERNORM_LAYERS, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from ...utils import ( - DUMMY_INPUTS, - DUMMY_MASK, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_torch_fx_proxy, - logging, - replace_return_docstrings, -) -from ...utils.model_parallel_utils import assert_device_map, get_device_map -from .configuration_timesfm import TimesFMConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "TimesFMConfig" -_CHECKPOINT_FOR_DOC = "google/timesfm-1.0-200m" - #################################################### # PyTorch Models are constructed by sub-classing # - torch.nn.Module for the layers and # - PreTrainedModel for the models (it-self a sub-class of nn.Module) #################################################### -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, *optional*): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the timesfm models have the - following number of attention modules: - - - google/timesfm-1.0-200m: 6 - - google-timesfm/timesfm-base: 12 - - google-timesfm/timesfm-large: 24 - - google-timesfm/timesfm-3b: 24 - - google-timesfm/timesfm-11b: 24 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs using google-timesfm/timesfm-3b, which has a total of 24 attention modules: - model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. - - Example: - - ```python - # On a 4 GPU machine with google-timesfm/timesfm-3b: - model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - -# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->TimesFM -class TimesFMLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Construct a layernorm module in the TimesFM style. No bias and no subtraction of mean. - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - # TimesFM uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) +import logging +from os import path +from typing import Any, Sequence - return self.weight * hidden_states - - -try: - from apex.normalization import FusedRMSNorm - - TimesFMLayerNorm = FusedRMSNorm # noqa - - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm" - ) -except ImportError: - # using the normal TimesFMLayerNorm - pass -except Exception: - logger.warning( - "discovered apex but it failed to load, falling back to TimesFMLayerNorm" - ) - pass - -ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) - - -class TimesFMResidualBlock(nn.Module): - def __init__(self, input_dims, hidden_dims, output_dims, dropout=0.1): - super().__init__() - - self.hidden_layer = nn.Sequential(nn.Linear(input_dims, hidden_dims), nn.SiLU()) - self.output_layer = nn.Linear(hidden_dims, output_dims) - self.residual_layer = nn.Linear(input_dims, output_dims) - self.dropout = nn.Dropout(dropout) - - def forward(self, inputs): - hidden = self.hidden_layer(inputs) - output = self.output_layer(hidden) - output = self.dropout(output) - residual = self.residual_layer(inputs) - - return output + residual - - -class TimesFMPositionalEmbedding(nn.Module): - """Generates position embedding for a given 1-d sequence. - - Attributes: - min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. - max_timescale: End of the geometric index. Determines the frequency of the - added signal. - embedding_dims: Dimension of the embedding to be generated. - """ +import numpy as np +import torch +from huggingface_hub import snapshot_download +import timesfm_base +import patched_decoder as ppd +from ...modeling_utils import PreTrainedModel - def __init__(self, min_timescale=1, max_timescale=10000, embedding_dims=0): - super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dims = embedding_dims +_TOL = 1e-6 - def forward(self, seq_length=None, position=None): - """Generates a tensor of sinusoids with different frequencies. - Args: - seq_length: an optional Python int defining the output sequence length. - if the `position` argument is specified. - position: [B, seq_length], optional position for each token in the - sequence, only required when the sequence is packed. +class TimesFmTorch(PreTrainedModel, timesfm_base.TimesFmBase): + """TimesFM forecast API for inference.""" - Returns: - [B, seqlen, D] if `position` is specified, else [1, seqlen, D] - """ - if position is None: - if seq_length is None: - raise ValueError("If position is None, seq_length should be specified.") - # [1, seqlen] - position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) - else: - if position.ndim != 2: - raise ValueError( - f"position should have 2 dimensions, got {position.ndim}" - ) - - num_timescales = self.embedding_dims // 2 - log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale) - ) / max(torch.tensor(num_timescales, dtype=torch.float32) - 1, 1) - inv_timescales = self.min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + def __post_init__(self): + self._model_config = ppd.TimesFMConfig( + num_layers=self.num_layers, + num_heads=self.num_heads, + hidden_size=self.model_dims, + intermediate_size=self.model_dims, + patch_len=self.input_patch_len, + horizon_len=self.output_patch_len, + head_dim=self.model_dims // self.num_heads, + quantiles=self.quantiles, ) - scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) - signal = torch.cat( - [torch.sin(scaled_time), torch.cos(scaled_time)], dim=2 - ).type(torch.float32) - - signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) - return signal - - -class TimesFMDenseActDense(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.wi = nn.Linear(config.d_model, config.d_ff, bias=True) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=True) - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ACT2FN[config.dense_act_fn] - - def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states) - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): - hidden_states = hidden_states.to(self.wo.weight.dtype) - hidden_states = self.wo(hidden_states) - return hidden_states - - -class TimesFMLayerFF(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - - self.DenseReluDense = TimesFMDenseActDense(config) - self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class TimesFMPerHeadDimScale(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - dim = config.d_model // config.num_heads - r_softplus_0 = 1.442695041 - self.scale_factor = r_softplus_0 / math.sqrt(dim) - self.scale = nn.Parameter(torch.empty(self.dim)) - - def forward(self, hidden_states): - scale = self.scale_factor * F.softplus(self.scale) - return hidden_states * scale - - -class TimesFMAttention(nn.Module): - def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): - super().__init__() - self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(self.d_model, self.inner_dim, bias=True) - self.k = nn.Linear(self.d_model, self.inner_dim, bias=True) - self.v = nn.Linear(self.d_model, self.inner_dim, bias=True) - self.o = nn.Linear(self.inner_dim, self.d_model, bias=True) - self.per_head_dim_scale = TimesFMPerHeadDimScale(config) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding( - self.relative_attention_num_buckets, self.n_heads - ) - self.pruned_heads = set() - self.gradient_checkpointing = False - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + self._model = None + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device( + "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" ) - # Prune linear layers - self.q = prune_linear_layer(self.q, index) - self.k = prune_linear_layer(self.k, index) - self.v = prune_linear_layer(self.v, index) - self.o = prune_linear_layer(self.o, index, dim=1) - # Update hyper params - self.n_heads = self.n_heads - len(heads) - self.inner_dim = self.key_value_proj_dim * self.n_heads - self.pruned_heads = self.pruned_heads.union(heads) - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on + def load_from_checkpoint( + self, + checkpoint: timesfm_base.TimesFmCheckpoint, + ) -> None: + """Loads a checkpoint and compiles the decoder.""" + checkpoint_path = checkpoint.path + repo_id = checkpoint.huggingface_repo_id + if checkpoint_path is None: + checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt") + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) + logging.info("Loading checkpoint from %s", checkpoint_path) + self._model.load_state_dict(loaded_checkpoint) + logging.info("Sending checkpoint to device %s", f"{self._device}") + self._model.to(self._device) + self._model.eval() + # TODO: add compilation. + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). - def compute_bias(self, query_length, key_length, device=None): - """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - ): + Raises: + ValueError: If the checkpoint is not properly loaded. """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length + if not self._model: + raise ValueError( + "Checkpoint not loaded. Call `load_from_checkpoint` before" + " `forecast`." ) - - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) - - def shape(states): - """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) - - def unshape(states): - """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - unscaled_query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - query_states = self.per_head_dim_scale(unscaled_query_states) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(timesfm_base.moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) ) - if self.gradient_checkpointing and self.training: - position_bias.requires_grad = True - else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = ( - position_bias + mask - ) # (batch_size, n_heads, seq_length, key_length) - - if self.pruned_heads: - mask = torch.ones(position_bias.shape[1]) - mask[list(self.pruned_heads)] = 0 - position_bias_masked = position_bias[:, mask.bool()] - else: - position_bias_masked = position_bias - - scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (attn_weights,) - return outputs - - -class TimesFMTransformerLayer(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.attention = TimesFMAttention(config) - self.ff = TimesFMLayerFF(config) - self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, inputs, mask=None): - x = self.layer_norm(inputs) - x = self.attention(x, mask=mask) - x = self.dropout(x) - x = x + inputs - x = self.ff(x) - return x - - -class TimesFMTransformerStack(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.layers = nn.ModuleList( - [TimesFMTransformerLayer(config) for _ in range(config.num_layers)] - ) - - def forward(self, hidden_states, mask=None): - for layer in self.layers: - hidden_states = layer(hidden_states, mask=mask) - return hidden_states - - -class TimesFMModel(PreTrainedModel): - def __init__(self, config: TimesFMConfig): - super().__init__(config) - - self.freq_emb = nn.Embedding( - num_embeddings=config.freq_size, - embedding_dim=config.d_model, - ) - self.position_emb = TimesFMPositionalEmbedding( - embedding_dims=config.d_model, - ) - - self.input_ff_layer = TimesFMResidualBlock( - input_dims=config.patch_len * 2, - output_dims=config.d_model, - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - self.stacked_transformer_layer = TimesFMTransformerStack(config) - - def preprocess_inputs(self, inputs): - assert len(inputs.shape) == 3 # (batch_size, num_patches, patch_len) - inputs_mean = inputs.mean(dim=(1, 2)) - inputs_std = inputs.std(dim=(1, 2)) - processed_input = (inputs - inputs_mean[:, None, None]) / inputs_std[ - :, None, None - ] - return processed_input, (inputs_mean, inputs_std) - - def create_causal_mask(batch_size, seq_len): - mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() - mask = mask.unsqueeze(0).unsqueeze(1) - mask = mask.expand(batch_size, 1, seq_len, seq_len) - mask = mask.float().masked_fill(mask, -2.3819763e38).masked_fill(~mask, 0.0) - return mask - - def forward( - self, - input_ts, - ): - batch_size = input_ts.shape[0] - patched_inputs = input_ts.reshape(batch_size, -1, self.config.patch_len) - patched_pads = torch.zeros_like(patched_inputs) - patched_inputs, input_stats = self.preprocess_inputs(patched_inputs) - concat_inputs = torch.concat([patched_inputs, patched_pads], dim=-1) - - model_input = self.input_ff_layer(concat_inputs) - position_emb = self.position_emb(seq_length=model_input.shape[1]).expand( - model_input.shape[0], -1, -1 - ) - model_input = model_input + position_emb - f_emb = self.freq_emb( - torch.zeros((batch_size, 1), dtype=torch.long) - ) # freq set to zero, change if needed - model_input = model_input + f_emb - mask = self.create_causal_mask(model_input.shape[0], model_input.shape[1]) - model_output = self.stacked_transformer_layer(model_input, mask=mask) - return model_output, input_stats - - -class TimesFMPredictionHead(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.config = config - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.d_model, - output_dims=config.horizon_len, - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - def postprocess_outputs(self, outputs, stats): - mean, std = stats - return outputs * std[:, None, None, None] + mean[:, None, None, None] - - def forward(self, model_output, input_stats): - batch_size = model_output.shape[0] - output_ts = self.horizon_ff_layer(model_output) - - assert self.config.d_model % self.config.horizon_len == 0 - num_outputs = self.config.d_model // self.config.horizon_len - - output_ts = output_ts.reshape( - batch_size, -1, self.config.horizon_len, num_outputs - ) - output_ts = self.postprocess_outputs(output_ts, input_stats) - return output_ts - - -class TimesFMForPrediction(PreTrainedModel): - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.timesfm = TimesFMModel(config) - self.prediction_head = TimesFMPredictionHead(config) - - def forward( - self, - input_ts, - ): - model_output, input_stats = self.timesfm(input_ts) - output_ts = self.prediction_head(model_output, input_stats) - return output_ts - - -TIMESFM_START_DOCSTRING = r""" - - The TIMESFM model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`TimesFMConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TIMESFM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - TIMESFM uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [TIMESFM - Training](./timesfm#training). - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in - `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs diff --git a/src/transformers/models/timesfm/patched_decoder.py b/src/transformers/models/timesfm/patched_decoder.py new file mode 100644 index 000000000000..f7e108bc08d8 --- /dev/null +++ b/src/transformers/models/timesfm/patched_decoder.py @@ -0,0 +1,766 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytorch version of patched decoder.""" + + +import math +from typing import List, Tuple +import torch +from torch import nn +import torch.nn.functional as F +from transformers.models.timesfm.configuration_timesfm import TimesFMConfig + + +def _masked_mean_std( + inputs: torch.Tensor, padding: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded + values. + """ + # Selecting the first patch with more than 3 unpadded values. + pad_sum = torch.sum(1 - padding, dim=2) + + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.where( + num_valid_elements == 0, + torch.tensor( + 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device + ), + num_valid_elements, + ) + + # Calculate the masked sum and squared sum + masked_sum = torch.sum(arr * mask, dim=1) + masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) + + # Calculate the masked mean and standard deviation + masked_mean = masked_sum / num_valid_elements + masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 + masked_var = torch.where( + masked_var < 0.0, + torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), + masked_var, + ) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + +def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + Returns the shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = ( + torch.arange(num_seq) + .to(seq.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, -1, feature_dim) + ) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + +def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Applies a floating-point mask to a set of logits. + + Args: + logits: A torch.Tensor of logit values. + mask: A torch.Tensor (float32) of mask values with the encoding described + in the function documentation. + + Returns: + Masked logits. + """ + + min_value = get_large_negative_number(logits.dtype) + + return torch.where((mask >= min_value * 0.5), logits, min_value) + + +def convert_paddings_to_mask( + paddings: torch.Tensor, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + """Converts binary paddings to a logit mask ready to add to attention matrix. + + Args: + paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding + token. + dtype: data type of the input. + + Returns: + A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. + """ + attention_mask = paddings.detach().clone() + attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis + attention_mask *= get_large_negative_number(dtype) + return attention_mask + + +def causal_mask(input_t: torch.Tensor) -> torch.Tensor: + """Computes and returns causal mask. + + Args: + input_t: A torch.Tensor of shape [B, T, D]. + + Returns: + An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has + already been converted to large negative values. + """ + assert input_t.dtype.is_floating_point, input_t.dtype + large_negative_number = get_large_negative_number(input_t.dtype) + t = input_t.shape[1] + col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) + row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) + mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number + return ( + mask.unsqueeze(0).unsqueeze(0).to(input_t.device) + ) # Equivalent to jnp.newaxis + + +def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Merges 2 masks. + + logscale mask is expected but 0/1 mask is also fine. + + Args: + a: torch.Tensor of shape [1|B, 1, 1|T, S]. + b: torch.Tensor of shape [1|B, 1, 1|T, S]. + + Returns: + torch.Tensor of shape [1|B, 1, 1|T, S]. + """ + + def expand_t(key_mask): + query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose + return torch.minimum(query_mask, key_mask) + + if a.shape[2] != b.shape[2]: + if a.shape[2] == 1: + a = expand_t(a) + else: + assert b.shape[2] == 1 + b = expand_t(b) + + assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." + return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum + + +class ResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super(ResidualBlock, self).__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.hidden_layer(x) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class RMSNorm(torch.nn.Module): + """Pax rms norm in pytorch.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + if self.add_unit_offset: + output = output * (1 + self.weight.float()) + else: + output = output * self.weight.float() + return output.type_as(x) + + +class TransformerMLP(nn.Module): + """Pax transformer MLP in pytorch.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFMAttention(nn.Module): + """Implements the attention used in TimesFM.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = nn.Parameter( + torch.empty((self.head_dim,), dtype=torch.float32), + ) + + self.qkv_proj = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: + # [batch_size, n_local_heads, input_len, head_dim] + r_softplus_0 = 1.442695041 + softplus_func = torch.nn.Softplus() + scale = r_softplus_0 / math.sqrt(self.head_dim) + scale = scale * softplus_func(self.scaling) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xq = self._per_dim_scaling(xq) + + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + else: + key = xk + value = xv + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + # return scores, output.transpose(1, 2).contiguous() + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return scores, output + + +class TimesFMDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.self_attn = TimesFMAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + ) + self.mlp = TransformerMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + scores, hidden_states = self.self_attn( + hidden_states=hidden_states, + mask=mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class StackedDecoder(nn.Module): + """Stacked transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + num_layers: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TimesFMDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + ) + ) + + def forward( + self, + hidden_states: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + ) -> torch.Tensor: + padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) + atten_mask = causal_mask(hidden_states) + mask = merge_masks(padding_mask, atten_mask) + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = kv_caches[i] if kv_caches is not None else None + _, hidden_states = layer( + hidden_states=hidden_states, + mask=mask, + paddings=paddings, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + return hidden_states + + +class PositionalEmbedding(torch.nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__( + self, + embedding_dims: int, + min_timescale: int = 1, + max_timescale: int = 10_000, + ) -> None: + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + assert seq_length is not None + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + assert position.ndim == 2, position.shape + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / max(num_timescales - 1, 1) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +class PatchedTimeSeriesDecoder(nn.Module): + """Patched time-series decoder.""" + + def __init__(self, config: TimesFMConfig): + super().__init__() + self.config = config + self.input_ff_layer = ResidualBlock( + input_dims=2 * config.patch_len, + output_dims=config.model_dim, + hidden_dims=config.model_dim, + ) + self.freq_emb = nn.Embedding(num_embeddings=3, embedding_dim=config.model_dim) + self.horizon_ff_layer = ResidualBlock( + input_dims=config.model_dim, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.model_dim, + ) + self.stacked_transformer = StackedDecoder( + hidden_size=self.config.model_dim, + intermediate_size=self.config.model_dim, + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_heads, + head_dim=self.config.head_dim, + num_layers=self.config.num_layers, + rms_norm_eps=self.config.rms_norm_eps, + ) + if self.config.use_positional_embedding: + self.position_emb = PositionalEmbedding(self.config.model_dim) + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = _masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=outputs.dtype, device=outputs.device + ), + outputs, + ) + return outputs, (mu, sigma) + + def _reverse_transform( + self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] + + def _preprocess_input( + self, + input_ts: torch.Tensor, + input_padding: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor, torch.Tensor] | None, + torch.Tensor, + ]: + """Preprocess input for stacked transformer.""" + + # Reshape into patches (using view for efficiency) + bsize = input_ts.shape[0] + patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) + patched_pads = input_padding.view(bsize, -1, self.config.patch_len) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, dim=-1)[ + 0 + ] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = _shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + return model_input, patched_padding, stats, patched_inputs + + def _postprocess_output( + self, + model_output: torch.Tensor, + num_outputs: int, + stats: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) + + return self._reverse_transform(output_ts, stats) + + def forward( + self, + input_ts: torch.Tensor, + input_padding: torch.LongTensor, + freq: torch.Tensor, + ) -> torch.Tensor: + num_outputs = len(self.config.quantiles) + 1 + model_input, patched_padding, stats, _ = self._preprocess_input( + input_ts=input_ts, + input_padding=input_padding, + ) + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + model_output = self.stacked_transformer(model_input, patched_padding) + + output_ts = self._postprocess_output(model_output, num_outputs, stats) + return output_ts + + def decode( + self, + input_ts: torch.Tensor, + paddings: torch.Tensor, + freq: torch.LongTensor, + horizon_len: int, + output_patch_len: int | None = None, + max_len: int = 512, + return_forecast_on_context: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Auto-regressive decoding without caching. + + Args: + input_ts: input time-series and paddings. Time-series shape B x C. + paddings: padding shape B x (C + H) where H is the prediction length. + freq: frequency shape B x 1 + horizon_len: prediction length. + output_patch_len: output length to be fetched from one step of + auto-regressive decoding. + max_len: maximum training context length. + return_forecast_on_context: whether to return the model forecast on the + context except the first input patch. + + Returns: + Tuple of two forecasting results: + - Point (mean) output predictions as a tensor with shape B x H'. + - Full predictions (mean and quantiles) as a tensor with shape + B x H' x (1 + # quantiles). + In particular, if return_forecast_on_context is True, H' is H plus + the forecastable context length, i.e. context_len - (first) patch_len. + """ + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" + ) + if output_patch_len is None: + output_patch_len = self.config.horizon_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = paddings[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -max_len:] + input_padding = current_padding[:, -max_len:] + fprop_outputs = self(input_ts, input_padding, freq) + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] + new_full_ts = fprop_outputs.view( + new_full_ts.size(0), -1, new_full_ts.size(3) + ) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concatenate([final_out, new_ts], axis=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concatenate(full_outputs, axis=1)[ + :, : (context_len - self.config.patch_len + horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] + + return (full_outputs[:, :, 0], full_outputs) diff --git a/src/transformers/models/timesfm/timesfm_base.py b/src/transformers/models/timesfm/timesfm_base.py new file mode 100644 index 000000000000..c5f113ee6000 --- /dev/null +++ b/src/transformers/models/timesfm/timesfm_base.py @@ -0,0 +1,572 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for TimesFM inference. This will be common to PAX and Pytorch.""" + +import collections +import dataclasses +import logging +import multiprocessing +from typing import Any, Literal, Sequence + +import numpy as np +import pandas as pd + +from utilsforecast.processing import make_future_dataframe +from configuration_timesfm import TimesFMConfig +import xreg_lib + +Category = xreg_lib.Category +XRegMode = xreg_lib.XRegMode + +_TOL = 1e-6 +DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) + + +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if ( + freq.endswith("H") + or freq.endswith("T") + or freq.endswith("MIN") + or freq.endswith("D") + or freq.endswith("B") + or freq.endswith("U") + ): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + +# Per time series normalization: forward. +def normalize(batch): + stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch] + new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] + return new_batch, stats + + +# Per time series normalization: inverse. +def renormalize(batch, stats): + return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] + + +@dataclasses.dataclass(kw_only=True) +class TimesFmCheckpoint: + """Checkpoint used to initialize a TimesFM model for inference. + + Attributes: + version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. + The factory will create the corresponding TimesFm inference class based on + this version. + path: Path to the checkpoint. + type: If provided, type of the checkpoint used by the specific checkpoint + loader per version. + step: If provided, step of the checkpoint. + """ + + version: str = "jax" + path: str | None = None + huggingface_repo_id: str | None = None + type: Any = None + step: int | None = None + + +class TimesFmBase: + """Base TimesFM forecast API for inference. + + This class is the scaffolding for calling TimesFM forecast. To properly use: + 1. Create an instance with the correct hyperparameters of a TimesFM model. + 2. Call `load_from_checkpoint` to load a compatible checkpoint. + 3. Call `forecast` for inference. + """ + + def _logging(self, s): + print(s) + + def __post_init__(self) -> None: + """Additional initialization for subclasses before checkpoint loading.""" + pass + + def __init__(self, hparams: TimesFMConfig, checkpoint: TimesFmCheckpoint) -> None: + """Initializes the TimesFM forecast API. + + Args: + hparams: Hyperparameters of the model. + checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide + which TimesFM version to use. + """ + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.patch_len + self.output_patch_len = hparams.horizon_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dim + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + + # Rewrite these values in __post_init__ for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + + self._horizon_start = self.context_len - self.input_patch_len + self.__post_init__() + self.load_from_checkpoint(checkpoint) + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.array], freq: Sequence[int] + ) -> tuple[np.array, np.array, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d JTensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ( + (len(inputs) - 1) // self.global_batch_size + 1 + ) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate( + [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 + ) + padding = np.concatenate( + [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 + ) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.array, np.array]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + raise NotImplementedError("`forecast` is not implemented.") + + def forecast_with_covariates( + self, + inputs: list[Sequence[float]], + dynamic_numerical_covariates: ( + dict[str, Sequence[Sequence[float]]] | None + ) = None, + dynamic_categorical_covariates: ( + dict[str, Sequence[Sequence[Category]]] | None + ) = None, + static_numerical_covariates: dict[str, Sequence[float]] | None = None, + static_categorical_covariates: dict[str, Sequence[Category]] | None = None, + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + xreg_mode: XRegMode = "xreg + timesfm", + normalize_xreg_target_per_input: bool = True, + ridge: float = 0.0, + max_rows_per_col: int = 0, + force_on_cpu: bool = False, + ): + """Forecasts on a list of time series with covariates. + + To optimize inference speed, avoid string valued categorical covariates. + + Args: + inputs: A list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + dynamic_numerical_covariates: A dict of dynamic numerical covariates. + dynamic_categorical_covariates: A dict of dynamic categorical covariates. + static_numerical_covariates: A dict of static numerical covariates. + static_categorical_covariates: A dict of static categorical covariates. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm" + fits a model on the residuals of the TimesFM forecast. "timesfm + xreg" + fits a model on the targets then forecasts on the residuals via TimesFM. + normalize_xreg_target_per_input: whether to normalize the xreg target per + input in the given batch. + ridge: ridge penalty for the linear model. + max_rows_per_col: max number of rows per column for the linear model. + force_on_cpu: whether to force running on cpu for the linear model. + + Returns: + A tuple of two lists. The first is the outputs of the model. The second is + the outputs of the xreg. + """ + + # Verify and bookkeep covariates. + if not ( + dynamic_numerical_covariates + or dynamic_categorical_covariates + or static_numerical_covariates + or static_categorical_covariates + ): + raise ValueError( + "At least one of dynamic_numerical_covariates," + " dynamic_categorical_covariates, static_numerical_covariates," + " static_categorical_covariates must be set." + ) + + # Track the lengths of (1) each input, (2) the part that can be used in the + # linear model, and (3) the horizon. + input_lens, train_lens, test_lens = [], [], [] + + for i, input_ts in enumerate(inputs): + input_len = len(input_ts) + input_lens.append(input_len) + + if xreg_mode == "timesfm + xreg": + # For fitting residuals, no TimesFM forecast on the first patch. + train_lens.append(max(0, input_len - self.input_patch_len)) + elif xreg_mode == "xreg + timesfm": + train_lens.append(input_len) + else: + raise ValueError(f"Unsupported mode: {xreg_mode}") + + if dynamic_numerical_covariates: + test_lens.append( + len(list(dynamic_numerical_covariates.values())[0][i]) - input_len + ) + elif dynamic_categorical_covariates: + test_lens.append( + len(list(dynamic_categorical_covariates.values())[0][i]) - input_len + ) + else: + test_lens.append(self.horizon_len) + + if test_lens[-1] > self.horizon_len: + raise ValueError( + "Forecast requested longer horizon than the model definition " + f"supports: {test_lens[-1]} vs {self.horizon_len}." + ) + + # Prepare the covariates into train and test. + train_dynamic_numerical_covariates = collections.defaultdict(list) + test_dynamic_numerical_covariates = collections.defaultdict(list) + train_dynamic_categorical_covariates = collections.defaultdict(list) + test_dynamic_categorical_covariates = collections.defaultdict(list) + for covariates, train_covariates, test_covariates in ( + ( + dynamic_numerical_covariates, + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates, + ), + ( + dynamic_categorical_covariates, + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates, + ), + ): + if not covariates: + continue + for covariate_name, covariate_values in covariates.items(): + for input_len, train_len, covariate_value in zip( + input_lens, train_lens, covariate_values + ): + train_covariates[covariate_name].append( + covariate_value[(input_len - train_len) : input_len] + ) + test_covariates[covariate_name].append(covariate_value[input_len:]) + + # Fit models. + if xreg_mode == "timesfm + xreg": + # Forecast via TimesFM then fit a model on the residuals. + mean_outputs, _ = self.forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + targets = [ + ( + np.array(input_ts)[-train_len:] + - mean_output[ + (self._horizon_start - train_len) : self._horizon_start + ] + ) + for input_ts, mean_output, train_len in zip( + inputs, mean_outputs, train_lens + ) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = normalize(targets) + xregs = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=False, + assert_covariates=True, + assert_covariate_shapes=True, + ) + if normalize_xreg_target_per_input: + xregs = renormalize(xregs, per_instance_stats) + outputs = [ + ( + mean_output[self._horizon_start : (self._horizon_start + test_len)] + + xreg + ) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + + else: + # Fit a model on the targets then forecast on the residuals via TimesFM. + targets = [ + np.array(input_ts)[-train_len:] + for input_ts, train_len in zip(inputs, train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = normalize(targets) + xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=True, + assert_covariates=True, + assert_covariate_shapes=True, + ) + mean_outputs, _ = self.forecast( + [ + target - xreg_on_context + for target, xreg_on_context in zip(targets, xregs_on_context) + ], + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + outputs = [ + ( + mean_output[self._horizon_start : (self._horizon_start + test_len)] + + xreg + ) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + if normalize_xreg_target_per_input: + outputs = renormalize(outputs, per_instance_stats) + + return outputs, xregs + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: int | None = None, + num_jobs: int = 1, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. + + Args: + inputs: A pd.DataFrame of all time series. The dataframe should have a + `unique_id` column for identifying the time series, a `ds` column for + timestamps and a value column for the time series values. + freq: string valued `freq` of data. Notice this is different from the + `freq` required by `forecast`. See `freq_map` for allowed values. + forecast_context_len: If provided none zero, we take the last + `forecast_context_len` time-points from each series as the forecast + context instead of the `context_len` set by the model. + value_name: The name of the value column. + model_name: name of the model to be written into future df. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + num_jobs: number of parallel processes to use for dataframe processing. + verbose: output model states in terminal. + + Returns: + Future forecasts dataframe. + """ + if not ( + "unique_id" in inputs.columns + and "ds" in inputs.columns + and value_name in inputs.columns + ): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns." + ) + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + print("Processing dataframe with single process.") + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + print("Processing dataframe with multiple processes.") + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [ + (key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id") + ], + ) + new_inputs, uids = zip(*results) + if verbose: + print("Finished preprocessing dataframe.") + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast( + new_inputs, freq=freq_inps, window_size=window_size + ) + if verbose: + print("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, + ) + fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( + -1, 1 + ) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df diff --git a/src/transformers/models/timesfm/xreg_lib.py b/src/transformers/models/timesfm/xreg_lib.py new file mode 100644 index 000000000000..1c7d253990ca --- /dev/null +++ b/src/transformers/models/timesfm/xreg_lib.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for in-context covariates and regression.""" + +import itertools +import math +from typing import Any, Iterable, Literal, Mapping, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +from sklearn import preprocessing + +Category = int | str + +_TOL = 1e-6 +XRegMode = Literal["timesfm + xreg", "xreg + timesfm"] + + +def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray: + return np.array(list(itertools.chain.from_iterable(nested))) + + +def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray: + return np.array( + list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts))) + ) + + +def _to_padded_jax_array(x: np.ndarray) -> jax.Array: + if x.ndim == 1: + (i,) = x.shape + di = 2 ** math.ceil(math.log2(i)) - i + return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0) + elif x.ndim == 2: + i, j = x.shape + di = 2 ** math.ceil(math.log2(i)) - i + dj = 2 ** math.ceil(math.log2(j)) - j + return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0) + else: + raise ValueError(f"Unsupported array shape: {x.shape}") + + +class BatchedInContextXRegBase: + """Helper class for in-context regression covariate formatting. + + Attributes: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the static + categorical covariates of each forecast task. + """ + + def __init__( + self, + targets: Sequence[Sequence[float]], + train_lens: Sequence[int], + test_lens: Sequence[int], + train_dynamic_numerical_covariates: ( + Mapping[str, Sequence[Sequence[float]]] | None + ) = None, + train_dynamic_categorical_covariates: ( + Mapping[str, Sequence[Sequence[Category]]] | None + ) = None, + test_dynamic_numerical_covariates: ( + Mapping[str, Sequence[Sequence[float]]] | None + ) = None, + test_dynamic_categorical_covariates: ( + Mapping[str, Sequence[Sequence[Category]]] | None + ) = None, + static_numerical_covariates: Mapping[str, Sequence[float]] | None = None, + static_categorical_covariates: Mapping[str, Sequence[Category]] | None = None, + ) -> None: + """Initializes with the exogenous covariate inputs. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. We assume batched inputs. To properly format the + request: + + - `train_lens` represents the contexts in the batch. Targets and all train + dynamic covariates should have the same lengths as the corresponding + elements + in `train_lens`. Notice each `train_len` can be different from the exact + length of the corresponding context depending on how much of the context is + used for fitting the in-context model. + - `test_lens` represents the horizon lengths in the batch. All tesdt + dynamic + covariates should have the same lengths as the corresponding elements in + `test_lens`. + - Static covariates should be one for each input. + - For train and test dynamic covariates, they should have the same + covariate + names. + + Pass an empty dict {} for a covariate type if it is not present. + + Example: + Here is a set of valid inputs whose schema can be used for reference. + ``` + targets = [ + [0.0, 0.1, 0.2], + [0.0, 0.1, 0.2, 0.3], + ] # Two inputs in this batch. + train_lens = [3, 4] + test_lens = [2, 5] # Forecast horizons 2 and 5 respectively. + train_dynamic_numerical_covariates = { + "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]], + "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]], + } # Each train dynamic covariate has 3 and 4 elements respectively. + test_dynamic_numerical_covariates = { + "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]], + "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]], + } # Each test dynamic covariate has 2 and 5 elements respectively. + train_dynamic_categorical_covariates = { + "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]], + "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", + "bad"]], + } + test_dynamic_categorical_covariates = { + "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]], + "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]], + } + static_numerical_covariates = { + "cov_1_sn": [0.0, 3.0], + "cov_2_sn": [2.0, 1.0], + "cov_3_sn": [1.0, 2.0], + } # Each static covariate has 1 element for each input. + static_categorical_covariates = { + "cov_1_sc": ["apple", "orange"], + "cov_2_sc": [2, 3], + } + ``` + + Args: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the context. + Their lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the horizon. + Their lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the + static categorical covariates of each forecast task. + """ + self.targets = targets + self.train_lens = train_lens + self.test_lens = test_lens + self.train_dynamic_numerical_covariates = ( + train_dynamic_numerical_covariates or {} + ) + self.train_dynamic_categorical_covariates = ( + train_dynamic_categorical_covariates or {} + ) + self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {} + self.test_dynamic_categorical_covariates = ( + test_dynamic_categorical_covariates or {} + ) + self.static_numerical_covariates = static_numerical_covariates or {} + self.static_categorical_covariates = static_categorical_covariates or {} + + def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None: + """Verifies the validity of the covariate inputs.""" + + # Check presence. + if ( + self.train_dynamic_numerical_covariates + and not self.test_dynamic_numerical_covariates + ) or ( + not self.train_dynamic_numerical_covariates + and self.test_dynamic_numerical_covariates + ): + raise ValueError( + "train_dynamic_numerical_covariates and" + " test_dynamic_numerical_covariates must be both present or both" + " absent." + ) + + if ( + self.train_dynamic_categorical_covariates + and not self.test_dynamic_categorical_covariates + ) or ( + not self.train_dynamic_categorical_covariates + and self.test_dynamic_categorical_covariates + ): + raise ValueError( + "train_dynamic_categorical_covariates and" + " test_dynamic_categorical_covariates must be both present or both" + " absent." + ) + + # Check keys. + for dict_a, dict_b, dict_a_name, dict_b_name in ( + ( + self.train_dynamic_numerical_covariates, + self.test_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + "test_dynamic_numerical_covariates", + ), + ( + self.train_dynamic_categorical_covariates, + self.test_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + "test_dynamic_categorical_covariates", + ), + ): + if w := set(dict_a.keys()) - set(dict_b.keys()): + raise ValueError( + f"{dict_a_name} has keys not present in {dict_b_name}: {w}" + ) + if w := set(dict_b.keys()) - set(dict_a.keys()): + raise ValueError( + f"{dict_b_name} has keys not present in {dict_a_name}: {w}" + ) + + # Check shapes. + if assert_covariate_shapes: + if len(self.targets) != len(self.train_lens): + raise ValueError( + "targets and train_lens must have the same number of elements." + ) + + if len(self.train_lens) != len(self.test_lens): + raise ValueError( + "train_lens and test_lens must have the same number of elements." + ) + + for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)): + if len(target) != train_len: + raise ValueError( + f"targets[{i}] has length {len(target)} != expected {train_len}." + ) + + for key, values in self.static_numerical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_numerical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}." + ) + + for key, values in self.static_categorical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_categorical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}." + ) + + for lens, dict_cov, dict_cov_name in ( + ( + self.train_lens, + self.train_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + ), + ( + self.train_lens, + self.train_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_numerical_covariates, + "test_dynamic_numerical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_categorical_covariates, + "test_dynamic_categorical_covariates", + ), + ): + for key, cov_values in dict_cov.items(): + if len(cov_values) != len(lens): + raise ValueError( + f"{dict_cov_name} has key {key} with number of examples" + f" {len(cov_values)} != expected {len(lens)}." + ) + for i, cov_value in enumerate(cov_values): + if len(cov_value) != lens[i]: + raise ValueError( + f"{dict_cov_name} has key {key} with its {i}-th example" + f" length {len(cov_value)} != expected {lens[i]}." + ) + + def create_covariate_matrix( + self, + one_hot_encoder_drop: str | None = "first", + use_intercept: bool = True, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Creates target vector and covariate matrices for in context regression. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. + + Args: + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + A tuple of the target vector, the covariate matrix for the context, and + the covariate matrix for the horizon. + """ + if assert_covariates: + self._assert_covariates(assert_covariate_shapes) + + x_train, x_test = [], [] + + # Numerical features. + for name in sorted(self.train_dynamic_numerical_covariates): + x_train.append( + _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis] + ) + x_test.append( + _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis] + ) + + for covs in self.static_numerical_covariates.values(): + x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis]) + x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis]) + + if x_train: + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + # Normalize for robustness. + x_mean = np.mean(x_train, axis=0, keepdims=True) + x_std = np.where( + (w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0 + ) + x_train = [(x_train - x_mean) / x_std] + x_test = [(x_test - x_mean) / x_std] + + # Categorical features. Encode one by one. + one_hot_encoder = preprocessing.OneHotEncoder( + drop=one_hot_encoder_drop, + sparse_output=False, + handle_unknown="ignore", + ) + for name in sorted(self.train_dynamic_categorical_covariates.keys()): + ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[ + :, np.newaxis + ] + ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[ + :, np.newaxis + ] + x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train))) + x_test.append(np.array(one_hot_encoder.transform(ohe_test))) + + for covs in self.static_categorical_covariates.values(): + ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis]) + x_train.append(_repeat(ohe, self.train_lens)) + x_test.append(_repeat(ohe, self.test_lens)) + + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + if use_intercept: + x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0) + x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0) + + return _unnest(self.targets), x_train, x_test + + def fit(self) -> Any: + raise NotImplementedError("Fit is not implemented.") + + +class BatchedInContextXRegLinear(BatchedInContextXRegBase): + """Linear in-context regression model.""" + + def fit( + self, + ridge: float = 0.0, + one_hot_encoder_drop: str | None = "first", + use_intercept: bool = True, + force_on_cpu: bool = False, + max_rows_per_col: int = 0, + max_rows_per_col_sample_seed: int = 42, + debug_info: bool = False, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> ( + list[np.ndarray] + | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array] + ): + """Fits a linear model for in-context regression. + + Args: + ridge: A non-negative value for specifying the ridge regression penalty. + If 0 is provided, fallback to ordinary least squares. Note this penalty + is added to the normalized covariate matrix. + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + force_on_cpu: Whether to force execution on cpu for accelerator machines. + max_rows_per_col: How many rows to subsample per column. 0 for no + subsampling. This is for speeding up model fitting. + max_rows_per_col_sample_seed: The seed for the subsampling if needed by + `max_rows_per_col`. + debug_info: Whether to return debug info. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + If `debug_info` is False: + The linear fits on the horizon. + If `debug_info` is True: + A tuple of: + - the linear fits on the horizon, + - the linear fits on the context, + - the flattened target vector, + - the covariate matrix for the context, and + - the covariate matrix for the horizon. + """ + flat_targets, x_train_raw, x_test = self.create_covariate_matrix( + one_hot_encoder_drop=one_hot_encoder_drop, + use_intercept=use_intercept, + assert_covariates=assert_covariates, + assert_covariate_shapes=assert_covariate_shapes, + ) + + x_train = x_train_raw.copy() + if max_rows_per_col: + nrows, ncols = x_train.shape + if nrows > (w := ncols * max_rows_per_col): + subsample = jax.random.choice( + jax.random.PRNGKey(max_rows_per_col_sample_seed), + nrows, + (w,), + replace=False, + ) + x_train = x_train[subsample] + flat_targets = flat_targets[subsample] + + device = jax.devices("cpu")[0] if force_on_cpu else None + # Runs jitted version of the solvers which are quicker at the cost of + # running jitting during the first time calling. Re-jitting happens whenever + # new (padded) shapes are encountered. + # Ocassionally it helps with the speed and the accuracy if we force single + # thread execution on cpu for accelerator machines: + # 1. Avoid moving data to accelarator memory. + # 2. Avoid precision loss if any. + with jax.default_device(device): + x_train_raw = _to_padded_jax_array(x_train_raw) + x_train = _to_padded_jax_array(x_train) + flat_targets = _to_padded_jax_array(flat_targets) + x_test = _to_padded_jax_array(x_test) + beta_hat = ( + jnp.linalg.pinv( + x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]), + hermitian=True, + ) + @ x_train.T + @ flat_targets + ) + y_hat = x_test @ beta_hat + y_hat_context = x_train_raw @ beta_hat if debug_info else None + + outputs = [] + outputs_context = [] + + # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits. + train_index, test_index = 0, 0 + for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens): + outputs.append( + np.array(y_hat[test_index : (test_index + test_index_delta)]) + ) + if debug_info: + outputs_context.append( + np.array( + y_hat_context[train_index : (train_index + train_index_delta)] + ) + ) + train_index += train_index_delta + test_index += test_index_delta + + if debug_info: + return outputs, outputs_context, flat_targets, x_train, x_test + else: + return outputs From 119b2efe79558c37da9e123066654d35a8b5ab75 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Sun, 22 Sep 2024 19:07:03 -0700 Subject: [PATCH 012/242] remove covariate forecasting --- .../{modeling_timesfm.py => timesfm.py} | 0 .../models/timesfm/timesfm_base.py | 232 -------- src/transformers/models/timesfm/xreg_lib.py | 520 ------------------ 3 files changed, 752 deletions(-) rename src/transformers/models/timesfm/{modeling_timesfm.py => timesfm.py} (100%) delete mode 100644 src/transformers/models/timesfm/xreg_lib.py diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/timesfm.py similarity index 100% rename from src/transformers/models/timesfm/modeling_timesfm.py rename to src/transformers/models/timesfm/timesfm.py diff --git a/src/transformers/models/timesfm/timesfm_base.py b/src/transformers/models/timesfm/timesfm_base.py index c5f113ee6000..7c0c756e6847 100644 --- a/src/transformers/models/timesfm/timesfm_base.py +++ b/src/transformers/models/timesfm/timesfm_base.py @@ -25,10 +25,6 @@ from utilsforecast.processing import make_future_dataframe from configuration_timesfm import TimesFMConfig -import xreg_lib - -Category = xreg_lib.Category -XRegMode = xreg_lib.XRegMode _TOL = 1e-6 DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) @@ -245,234 +241,6 @@ def forecast( """ raise NotImplementedError("`forecast` is not implemented.") - def forecast_with_covariates( - self, - inputs: list[Sequence[float]], - dynamic_numerical_covariates: ( - dict[str, Sequence[Sequence[float]]] | None - ) = None, - dynamic_categorical_covariates: ( - dict[str, Sequence[Sequence[Category]]] | None - ) = None, - static_numerical_covariates: dict[str, Sequence[float]] | None = None, - static_categorical_covariates: dict[str, Sequence[Category]] | None = None, - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - xreg_mode: XRegMode = "xreg + timesfm", - normalize_xreg_target_per_input: bool = True, - ridge: float = 0.0, - max_rows_per_col: int = 0, - force_on_cpu: bool = False, - ): - """Forecasts on a list of time series with covariates. - - To optimize inference speed, avoid string valued categorical covariates. - - Args: - inputs: A list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - dynamic_numerical_covariates: A dict of dynamic numerical covariates. - dynamic_categorical_covariates: A dict of dynamic categorical covariates. - static_numerical_covariates: A dict of static numerical covariates. - static_categorical_covariates: A dict of static categorical covariates. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm" - fits a model on the residuals of the TimesFM forecast. "timesfm + xreg" - fits a model on the targets then forecasts on the residuals via TimesFM. - normalize_xreg_target_per_input: whether to normalize the xreg target per - input in the given batch. - ridge: ridge penalty for the linear model. - max_rows_per_col: max number of rows per column for the linear model. - force_on_cpu: whether to force running on cpu for the linear model. - - Returns: - A tuple of two lists. The first is the outputs of the model. The second is - the outputs of the xreg. - """ - - # Verify and bookkeep covariates. - if not ( - dynamic_numerical_covariates - or dynamic_categorical_covariates - or static_numerical_covariates - or static_categorical_covariates - ): - raise ValueError( - "At least one of dynamic_numerical_covariates," - " dynamic_categorical_covariates, static_numerical_covariates," - " static_categorical_covariates must be set." - ) - - # Track the lengths of (1) each input, (2) the part that can be used in the - # linear model, and (3) the horizon. - input_lens, train_lens, test_lens = [], [], [] - - for i, input_ts in enumerate(inputs): - input_len = len(input_ts) - input_lens.append(input_len) - - if xreg_mode == "timesfm + xreg": - # For fitting residuals, no TimesFM forecast on the first patch. - train_lens.append(max(0, input_len - self.input_patch_len)) - elif xreg_mode == "xreg + timesfm": - train_lens.append(input_len) - else: - raise ValueError(f"Unsupported mode: {xreg_mode}") - - if dynamic_numerical_covariates: - test_lens.append( - len(list(dynamic_numerical_covariates.values())[0][i]) - input_len - ) - elif dynamic_categorical_covariates: - test_lens.append( - len(list(dynamic_categorical_covariates.values())[0][i]) - input_len - ) - else: - test_lens.append(self.horizon_len) - - if test_lens[-1] > self.horizon_len: - raise ValueError( - "Forecast requested longer horizon than the model definition " - f"supports: {test_lens[-1]} vs {self.horizon_len}." - ) - - # Prepare the covariates into train and test. - train_dynamic_numerical_covariates = collections.defaultdict(list) - test_dynamic_numerical_covariates = collections.defaultdict(list) - train_dynamic_categorical_covariates = collections.defaultdict(list) - test_dynamic_categorical_covariates = collections.defaultdict(list) - for covariates, train_covariates, test_covariates in ( - ( - dynamic_numerical_covariates, - train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates, - ), - ( - dynamic_categorical_covariates, - train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates, - ), - ): - if not covariates: - continue - for covariate_name, covariate_values in covariates.items(): - for input_len, train_len, covariate_value in zip( - input_lens, train_lens, covariate_values - ): - train_covariates[covariate_name].append( - covariate_value[(input_len - train_len) : input_len] - ) - test_covariates[covariate_name].append(covariate_value[input_len:]) - - # Fit models. - if xreg_mode == "timesfm + xreg": - # Forecast via TimesFM then fit a model on the residuals. - mean_outputs, _ = self.forecast( - inputs, - freq, - window_size, - forecast_context_len, - return_forecast_on_context=True, - ) - targets = [ - ( - np.array(input_ts)[-train_len:] - - mean_output[ - (self._horizon_start - train_len) : self._horizon_start - ] - ) - for input_ts, mean_output, train_len in zip( - inputs, mean_outputs, train_lens - ) - ] - per_instance_stats = None - if normalize_xreg_target_per_input: - targets, per_instance_stats = normalize(targets) - xregs = xreg_lib.BatchedInContextXRegLinear( - targets=targets, - train_lens=train_lens, - test_lens=test_lens, - train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, - train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, - static_numerical_covariates=static_numerical_covariates, - static_categorical_covariates=static_categorical_covariates, - ).fit( - ridge=ridge, - one_hot_encoder_drop=None if ridge > 0 else "first", - max_rows_per_col=max_rows_per_col, - force_on_cpu=force_on_cpu, - debug_info=False, - assert_covariates=True, - assert_covariate_shapes=True, - ) - if normalize_xreg_target_per_input: - xregs = renormalize(xregs, per_instance_stats) - outputs = [ - ( - mean_output[self._horizon_start : (self._horizon_start + test_len)] - + xreg - ) - for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) - ] - - else: - # Fit a model on the targets then forecast on the residuals via TimesFM. - targets = [ - np.array(input_ts)[-train_len:] - for input_ts, train_len in zip(inputs, train_lens) - ] - per_instance_stats = None - if normalize_xreg_target_per_input: - targets, per_instance_stats = normalize(targets) - xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( - targets=targets, - train_lens=train_lens, - test_lens=test_lens, - train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, - train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, - static_numerical_covariates=static_numerical_covariates, - static_categorical_covariates=static_categorical_covariates, - ).fit( - ridge=ridge, - one_hot_encoder_drop=None if ridge > 0 else "first", - max_rows_per_col=max_rows_per_col, - force_on_cpu=force_on_cpu, - debug_info=True, - assert_covariates=True, - assert_covariate_shapes=True, - ) - mean_outputs, _ = self.forecast( - [ - target - xreg_on_context - for target, xreg_on_context in zip(targets, xregs_on_context) - ], - freq, - window_size, - forecast_context_len, - return_forecast_on_context=True, - ) - outputs = [ - ( - mean_output[self._horizon_start : (self._horizon_start + test_len)] - + xreg - ) - for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) - ] - if normalize_xreg_target_per_input: - outputs = renormalize(outputs, per_instance_stats) - - return outputs, xregs - def forecast_on_df( self, inputs: pd.DataFrame, diff --git a/src/transformers/models/timesfm/xreg_lib.py b/src/transformers/models/timesfm/xreg_lib.py deleted file mode 100644 index 1c7d253990ca..000000000000 --- a/src/transformers/models/timesfm/xreg_lib.py +++ /dev/null @@ -1,520 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Helper functions for in-context covariates and regression.""" - -import itertools -import math -from typing import Any, Iterable, Literal, Mapping, Sequence - -import jax -import jax.numpy as jnp -import numpy as np -from sklearn import preprocessing - -Category = int | str - -_TOL = 1e-6 -XRegMode = Literal["timesfm + xreg", "xreg + timesfm"] - - -def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray: - return np.array(list(itertools.chain.from_iterable(nested))) - - -def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray: - return np.array( - list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts))) - ) - - -def _to_padded_jax_array(x: np.ndarray) -> jax.Array: - if x.ndim == 1: - (i,) = x.shape - di = 2 ** math.ceil(math.log2(i)) - i - return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0) - elif x.ndim == 2: - i, j = x.shape - di = 2 ** math.ceil(math.log2(i)) - i - dj = 2 ** math.ceil(math.log2(j)) - j - return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0) - else: - raise ValueError(f"Unsupported array shape: {x.shape}") - - -class BatchedInContextXRegBase: - """Helper class for in-context regression covariate formatting. - - Attributes: - targets: List of targets (responses) of the in-context regression. - train_lens: List of lengths of each target vector from the context. - test_lens: List of lengths of each forecast horizon. - train_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the context. Their - lengths should match the corresponding lengths in `train_lens`. - train_dynamic_categorical_covariates: Dict of covariate names mapping to the - dynamic categorical covariates of each forecast task on the context. Their - lengths should match the corresponding lengths in `train_lens`. - test_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the horizon. Their - lengths should match the corresponding lengths in `test_lens`. - test_dynamic_categorical_covariates: Dict of covariate names mapping to the - dynamic categorical covariates of each forecast task on the horizon. Their - lengths should match the corresponding lengths in `test_lens`. - static_numerical_covariates: Dict of covariate names mapping to the static - numerical covariates of each forecast task. - static_categorical_covariates: Dict of covariate names mapping to the static - categorical covariates of each forecast task. - """ - - def __init__( - self, - targets: Sequence[Sequence[float]], - train_lens: Sequence[int], - test_lens: Sequence[int], - train_dynamic_numerical_covariates: ( - Mapping[str, Sequence[Sequence[float]]] | None - ) = None, - train_dynamic_categorical_covariates: ( - Mapping[str, Sequence[Sequence[Category]]] | None - ) = None, - test_dynamic_numerical_covariates: ( - Mapping[str, Sequence[Sequence[float]]] | None - ) = None, - test_dynamic_categorical_covariates: ( - Mapping[str, Sequence[Sequence[Category]]] | None - ) = None, - static_numerical_covariates: Mapping[str, Sequence[float]] | None = None, - static_categorical_covariates: Mapping[str, Sequence[Category]] | None = None, - ) -> None: - """Initializes with the exogenous covariate inputs. - - Here we use model fitting language to refer to the context as 'train' and - the horizon as 'test'. We assume batched inputs. To properly format the - request: - - - `train_lens` represents the contexts in the batch. Targets and all train - dynamic covariates should have the same lengths as the corresponding - elements - in `train_lens`. Notice each `train_len` can be different from the exact - length of the corresponding context depending on how much of the context is - used for fitting the in-context model. - - `test_lens` represents the horizon lengths in the batch. All tesdt - dynamic - covariates should have the same lengths as the corresponding elements in - `test_lens`. - - Static covariates should be one for each input. - - For train and test dynamic covariates, they should have the same - covariate - names. - - Pass an empty dict {} for a covariate type if it is not present. - - Example: - Here is a set of valid inputs whose schema can be used for reference. - ``` - targets = [ - [0.0, 0.1, 0.2], - [0.0, 0.1, 0.2, 0.3], - ] # Two inputs in this batch. - train_lens = [3, 4] - test_lens = [2, 5] # Forecast horizons 2 and 5 respectively. - train_dynamic_numerical_covariates = { - "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]], - "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]], - } # Each train dynamic covariate has 3 and 4 elements respectively. - test_dynamic_numerical_covariates = { - "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]], - "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]], - } # Each test dynamic covariate has 2 and 5 elements respectively. - train_dynamic_categorical_covariates = { - "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]], - "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", - "bad"]], - } - test_dynamic_categorical_covariates = { - "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]], - "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]], - } - static_numerical_covariates = { - "cov_1_sn": [0.0, 3.0], - "cov_2_sn": [2.0, 1.0], - "cov_3_sn": [1.0, 2.0], - } # Each static covariate has 1 element for each input. - static_categorical_covariates = { - "cov_1_sc": ["apple", "orange"], - "cov_2_sc": [2, 3], - } - ``` - - Args: - targets: List of targets (responses) of the in-context regression. - train_lens: List of lengths of each target vector from the context. - test_lens: List of lengths of each forecast horizon. - train_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the context. Their - lengths should match the corresponding lengths in `train_lens`. - train_dynamic_categorical_covariates: Dict of covariate names mapping to - the dynamic categorical covariates of each forecast task on the context. - Their lengths should match the corresponding lengths in `train_lens`. - test_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the horizon. Their - lengths should match the corresponding lengths in `test_lens`. - test_dynamic_categorical_covariates: Dict of covariate names mapping to - the dynamic categorical covariates of each forecast task on the horizon. - Their lengths should match the corresponding lengths in `test_lens`. - static_numerical_covariates: Dict of covariate names mapping to the static - numerical covariates of each forecast task. - static_categorical_covariates: Dict of covariate names mapping to the - static categorical covariates of each forecast task. - """ - self.targets = targets - self.train_lens = train_lens - self.test_lens = test_lens - self.train_dynamic_numerical_covariates = ( - train_dynamic_numerical_covariates or {} - ) - self.train_dynamic_categorical_covariates = ( - train_dynamic_categorical_covariates or {} - ) - self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {} - self.test_dynamic_categorical_covariates = ( - test_dynamic_categorical_covariates or {} - ) - self.static_numerical_covariates = static_numerical_covariates or {} - self.static_categorical_covariates = static_categorical_covariates or {} - - def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None: - """Verifies the validity of the covariate inputs.""" - - # Check presence. - if ( - self.train_dynamic_numerical_covariates - and not self.test_dynamic_numerical_covariates - ) or ( - not self.train_dynamic_numerical_covariates - and self.test_dynamic_numerical_covariates - ): - raise ValueError( - "train_dynamic_numerical_covariates and" - " test_dynamic_numerical_covariates must be both present or both" - " absent." - ) - - if ( - self.train_dynamic_categorical_covariates - and not self.test_dynamic_categorical_covariates - ) or ( - not self.train_dynamic_categorical_covariates - and self.test_dynamic_categorical_covariates - ): - raise ValueError( - "train_dynamic_categorical_covariates and" - " test_dynamic_categorical_covariates must be both present or both" - " absent." - ) - - # Check keys. - for dict_a, dict_b, dict_a_name, dict_b_name in ( - ( - self.train_dynamic_numerical_covariates, - self.test_dynamic_numerical_covariates, - "train_dynamic_numerical_covariates", - "test_dynamic_numerical_covariates", - ), - ( - self.train_dynamic_categorical_covariates, - self.test_dynamic_categorical_covariates, - "train_dynamic_categorical_covariates", - "test_dynamic_categorical_covariates", - ), - ): - if w := set(dict_a.keys()) - set(dict_b.keys()): - raise ValueError( - f"{dict_a_name} has keys not present in {dict_b_name}: {w}" - ) - if w := set(dict_b.keys()) - set(dict_a.keys()): - raise ValueError( - f"{dict_b_name} has keys not present in {dict_a_name}: {w}" - ) - - # Check shapes. - if assert_covariate_shapes: - if len(self.targets) != len(self.train_lens): - raise ValueError( - "targets and train_lens must have the same number of elements." - ) - - if len(self.train_lens) != len(self.test_lens): - raise ValueError( - "train_lens and test_lens must have the same number of elements." - ) - - for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)): - if len(target) != train_len: - raise ValueError( - f"targets[{i}] has length {len(target)} != expected {train_len}." - ) - - for key, values in self.static_numerical_covariates.items(): - if len(values) != len(self.train_lens): - raise ValueError( - f"static_numerical_covariates has key {key} with number of" - f" examples {len(values)} != expected {len(self.train_lens)}." - ) - - for key, values in self.static_categorical_covariates.items(): - if len(values) != len(self.train_lens): - raise ValueError( - f"static_categorical_covariates has key {key} with number of" - f" examples {len(values)} != expected {len(self.train_lens)}." - ) - - for lens, dict_cov, dict_cov_name in ( - ( - self.train_lens, - self.train_dynamic_numerical_covariates, - "train_dynamic_numerical_covariates", - ), - ( - self.train_lens, - self.train_dynamic_categorical_covariates, - "train_dynamic_categorical_covariates", - ), - ( - self.test_lens, - self.test_dynamic_numerical_covariates, - "test_dynamic_numerical_covariates", - ), - ( - self.test_lens, - self.test_dynamic_categorical_covariates, - "test_dynamic_categorical_covariates", - ), - ): - for key, cov_values in dict_cov.items(): - if len(cov_values) != len(lens): - raise ValueError( - f"{dict_cov_name} has key {key} with number of examples" - f" {len(cov_values)} != expected {len(lens)}." - ) - for i, cov_value in enumerate(cov_values): - if len(cov_value) != lens[i]: - raise ValueError( - f"{dict_cov_name} has key {key} with its {i}-th example" - f" length {len(cov_value)} != expected {lens[i]}." - ) - - def create_covariate_matrix( - self, - one_hot_encoder_drop: str | None = "first", - use_intercept: bool = True, - assert_covariates: bool = False, - assert_covariate_shapes: bool = False, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Creates target vector and covariate matrices for in context regression. - - Here we use model fitting language to refer to the context as 'train' and - the horizon as 'test'. - - Args: - one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. - use_intercept: Whether to prepare an intercept (all 1) column in the - matrices. - assert_covariates: Whether to assert the validity of the covariate inputs. - assert_covariate_shapes: Whether to assert the shapes of the covariate - inputs when `assert_covariates` is True. - - Returns: - A tuple of the target vector, the covariate matrix for the context, and - the covariate matrix for the horizon. - """ - if assert_covariates: - self._assert_covariates(assert_covariate_shapes) - - x_train, x_test = [], [] - - # Numerical features. - for name in sorted(self.train_dynamic_numerical_covariates): - x_train.append( - _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis] - ) - x_test.append( - _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis] - ) - - for covs in self.static_numerical_covariates.values(): - x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis]) - x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis]) - - if x_train: - x_train = np.concatenate(x_train, axis=1) - x_test = np.concatenate(x_test, axis=1) - - # Normalize for robustness. - x_mean = np.mean(x_train, axis=0, keepdims=True) - x_std = np.where( - (w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0 - ) - x_train = [(x_train - x_mean) / x_std] - x_test = [(x_test - x_mean) / x_std] - - # Categorical features. Encode one by one. - one_hot_encoder = preprocessing.OneHotEncoder( - drop=one_hot_encoder_drop, - sparse_output=False, - handle_unknown="ignore", - ) - for name in sorted(self.train_dynamic_categorical_covariates.keys()): - ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[ - :, np.newaxis - ] - ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[ - :, np.newaxis - ] - x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train))) - x_test.append(np.array(one_hot_encoder.transform(ohe_test))) - - for covs in self.static_categorical_covariates.values(): - ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis]) - x_train.append(_repeat(ohe, self.train_lens)) - x_test.append(_repeat(ohe, self.test_lens)) - - x_train = np.concatenate(x_train, axis=1) - x_test = np.concatenate(x_test, axis=1) - - if use_intercept: - x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0) - x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0) - - return _unnest(self.targets), x_train, x_test - - def fit(self) -> Any: - raise NotImplementedError("Fit is not implemented.") - - -class BatchedInContextXRegLinear(BatchedInContextXRegBase): - """Linear in-context regression model.""" - - def fit( - self, - ridge: float = 0.0, - one_hot_encoder_drop: str | None = "first", - use_intercept: bool = True, - force_on_cpu: bool = False, - max_rows_per_col: int = 0, - max_rows_per_col_sample_seed: int = 42, - debug_info: bool = False, - assert_covariates: bool = False, - assert_covariate_shapes: bool = False, - ) -> ( - list[np.ndarray] - | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array] - ): - """Fits a linear model for in-context regression. - - Args: - ridge: A non-negative value for specifying the ridge regression penalty. - If 0 is provided, fallback to ordinary least squares. Note this penalty - is added to the normalized covariate matrix. - one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. - use_intercept: Whether to prepare an intercept (all 1) column in the - matrices. - force_on_cpu: Whether to force execution on cpu for accelerator machines. - max_rows_per_col: How many rows to subsample per column. 0 for no - subsampling. This is for speeding up model fitting. - max_rows_per_col_sample_seed: The seed for the subsampling if needed by - `max_rows_per_col`. - debug_info: Whether to return debug info. - assert_covariates: Whether to assert the validity of the covariate inputs. - assert_covariate_shapes: Whether to assert the shapes of the covariate - inputs when `assert_covariates` is True. - - Returns: - If `debug_info` is False: - The linear fits on the horizon. - If `debug_info` is True: - A tuple of: - - the linear fits on the horizon, - - the linear fits on the context, - - the flattened target vector, - - the covariate matrix for the context, and - - the covariate matrix for the horizon. - """ - flat_targets, x_train_raw, x_test = self.create_covariate_matrix( - one_hot_encoder_drop=one_hot_encoder_drop, - use_intercept=use_intercept, - assert_covariates=assert_covariates, - assert_covariate_shapes=assert_covariate_shapes, - ) - - x_train = x_train_raw.copy() - if max_rows_per_col: - nrows, ncols = x_train.shape - if nrows > (w := ncols * max_rows_per_col): - subsample = jax.random.choice( - jax.random.PRNGKey(max_rows_per_col_sample_seed), - nrows, - (w,), - replace=False, - ) - x_train = x_train[subsample] - flat_targets = flat_targets[subsample] - - device = jax.devices("cpu")[0] if force_on_cpu else None - # Runs jitted version of the solvers which are quicker at the cost of - # running jitting during the first time calling. Re-jitting happens whenever - # new (padded) shapes are encountered. - # Ocassionally it helps with the speed and the accuracy if we force single - # thread execution on cpu for accelerator machines: - # 1. Avoid moving data to accelarator memory. - # 2. Avoid precision loss if any. - with jax.default_device(device): - x_train_raw = _to_padded_jax_array(x_train_raw) - x_train = _to_padded_jax_array(x_train) - flat_targets = _to_padded_jax_array(flat_targets) - x_test = _to_padded_jax_array(x_test) - beta_hat = ( - jnp.linalg.pinv( - x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]), - hermitian=True, - ) - @ x_train.T - @ flat_targets - ) - y_hat = x_test @ beta_hat - y_hat_context = x_train_raw @ beta_hat if debug_info else None - - outputs = [] - outputs_context = [] - - # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits. - train_index, test_index = 0, 0 - for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens): - outputs.append( - np.array(y_hat[test_index : (test_index + test_index_delta)]) - ) - if debug_info: - outputs_context.append( - np.array( - y_hat_context[train_index : (train_index + train_index_delta)] - ) - ) - train_index += train_index_delta - test_index += test_index_delta - - if debug_info: - return outputs, outputs_context, flat_targets, x_train, x_test - else: - return outputs From ce7fd1b122d4bb05ebae1708efc4abc8f728366c Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Mon, 30 Sep 2024 13:03:37 -0700 Subject: [PATCH 013/242] Adapting TimesFM to HF format --- src/transformers/models/timesfm/__init__.py | 6 +- .../models/timesfm/configuration_timesfm.py | 6 +- .../models/timesfm/modeling_timesfm.py | 626 ++++++++++++++++++ .../models/timesfm/patched_decoder.py | 22 +- src/transformers/models/timesfm/timesfm.py | 202 ------ .../models/timesfm/timesfm_base.py | 340 ---------- 6 files changed, 643 insertions(+), 559 deletions(-) create mode 100644 src/transformers/models/timesfm/modeling_timesfm.py delete mode 100644 src/transformers/models/timesfm/timesfm.py delete mode 100644 src/transformers/models/timesfm/timesfm_base.py diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index baa30b11af21..fe1a08da2678 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -29,8 +29,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["modeling_timesfm"] = [ - "TimesFMForPrediction", + _import_structure["timesfm"] = [ "TimesFMModel", "TimesFMPreTrainedModel", ] @@ -44,8 +43,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_timesfm import ( - TimesFMForPrediction, + from .timesfm import ( TimesFMModel, TimesFMPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index de82a874771b..29948593aff5 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -71,6 +71,8 @@ class TimesFMConfig(PretrainedConfig): initializer_factor (`float`, *optional*, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). + backend (`str`, *optional*, defaults to `"gpu"`): + The backend to use for the model. Can be either `"gpu"` or `"cpu"`. """ model_type = "timesfm" @@ -97,8 +99,9 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - per_core_batch_size: int = 32, + per_core_batch_size: int = 32, initializer_factor: float = 1.0, + backend: str = "gpu", **kwargs, ): self.patch_len = patch_len @@ -117,6 +120,7 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor + self.backend = backend super().__init__( **kwargs, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 000000000000..f2df0c061129 --- /dev/null +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,626 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TimesFM model.""" + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### + + +import dataclasses +import logging +import multiprocessing +from typing import Any, Sequence +from os import path +import pandas as pd +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from ...modeling_utils import PreTrainedModel + +import patched_decoder as ppd +from utilsforecast.processing import make_future_dataframe +from configuration_timesfm import TimesFMConfig + + +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if ( + freq.endswith("H") + or freq.endswith("T") + or freq.endswith("MIN") + or freq.endswith("D") + or freq.endswith("B") + or freq.endswith("U") + ): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + +@dataclasses.dataclass(kw_only=True) +class TimesFmCheckpoint: + """Checkpoint used to initialize a TimesFM model for inference. + + Attributes: + version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. + The factory will create the corresponding TimesFm inference class based on + this version. + path: Path to the checkpoint. + type: If provided, type of the checkpoint used by the specific checkpoint + loader per version. + step: If provided, step of the checkpoint. + """ + + version: str = "torch" + path: str | None = None + huggingface_repo_id: str | None = None + type: Any = None + step: int | None = None + + +class TimesFmBase: + """Base TimesFM forecast API for inference. + + This class is the scaffolding for calling TimesFM forecast. To properly use: + 1. Create an instance with the correct hyperparameters of a TimesFM model. + 2. Call `load_from_checkpoint` to load a compatible checkpoint. + 3. Call `forecast` for inference. + """ + + def _logging(self, s): + print(s) + + def __init__(self, hparams: TimesFMConfig) -> None: + """Initializes the TimesFM forecast API. + + Args: + hparams: Hyperparameters of the model. + checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide + which TimesFM version to use. + """ + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.patch_len + self.output_patch_len = hparams.horizon_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dim + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + + # Rewrite these values in subclasses for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + self._horizon_start = self.context_len - self.input_patch_len + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.array], freq: Sequence[int] + ) -> tuple[np.array, np.array, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d JTensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ( + (len(inputs) - 1) // self.global_batch_size + 1 + ) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate( + [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 + ) + padding = np.concatenate( + [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 + ) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.array, np.array]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + raise NotImplementedError("`forecast` is not implemented.") + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: int | None = None, + num_jobs: int = 1, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. + + Args: + inputs: A pd.DataFrame of all time series. The dataframe should have a + `unique_id` column for identifying the time series, a `ds` column for + timestamps and a value column for the time series values. + freq: string valued `freq` of data. Notice this is different from the + `freq` required by `forecast`. See `freq_map` for allowed values. + forecast_context_len: If provided none zero, we take the last + `forecast_context_len` time-points from each series as the forecast + context instead of the `context_len` set by the model. + value_name: The name of the value column. + model_name: name of the model to be written into future df. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + num_jobs: number of parallel processes to use for dataframe processing. + verbose: output model states in terminal. + + Returns: + Future forecasts dataframe. + """ + if not ( + "unique_id" in inputs.columns + and "ds" in inputs.columns + and value_name in inputs.columns + ): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns." + ) + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + print("Processing dataframe with single process.") + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + print("Processing dataframe with multiple processes.") + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [ + (key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id") + ], + ) + new_inputs, uids = zip(*results) + if verbose: + print("Finished preprocessing dataframe.") + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast( + new_inputs, freq=freq_inps, window_size=window_size + ) + if verbose: + print("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, + ) + fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( + -1, 1 + ) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df + + +class TimesFMModel(TimesFmBase, nn.Module): + """Body of the TimesFM model, excluding the head.""" + + def __post_init__(self): + self._model_config = TimesFMConfig( + num_layers=self.num_layers, + num_heads=self.num_heads, + hidden_size=self.model_dims, + intermediate_size=self.model_dims, + patch_len=self.input_patch_len, + horizon_len=self.output_patch_len, + head_dim=self.model_dims // self.num_heads, + quantiles=self.quantiles, + ) + + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device( + "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" + ) + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + self._model.to(self._device) + self._model.eval() + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) + ) + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs + + +class TimesFMModel(TimesFmBase, nn.Module): + """TimesFM forecast API for inference.""" + + def __init__(self, hparams: TimesFMConfig) -> None: + super.__init__(hparams) + self._model_config = hparams + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device( + "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" + ) + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + if not self._model: + raise ValueError( + "Checkpoint not loaded. Call `load_from_checkpoint` before" + " `forecast`." + ) + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) + ) + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs + + def forward(self, x, **kwargs): + if isinstance(x, pd.DataFrame): + assert "freq" in kwargs, "Frequency must be provided for DataFrame input." + return self.forecast_on_df(x, **kwargs) + else: + return self.forecast(x, **kwargs) + + +## TODO: Define the PreTrainedTimesFMModel class diff --git a/src/transformers/models/timesfm/patched_decoder.py b/src/transformers/models/timesfm/patched_decoder.py index f7e108bc08d8..baafe6be148c 100644 --- a/src/transformers/models/timesfm/patched_decoder.py +++ b/src/transformers/models/timesfm/patched_decoder.py @@ -31,22 +31,21 @@ def _masked_mean_std( It excludes values where `padding` is 1. Args: - inputs: A PyTorch tensor of shape [b, n, p]. - padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. Returns: - A tuple containing the mean and standard deviation. - We return the statistics of the first patch with more than three non-padded - values. + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded values. """ - # Selecting the first patch with more than 3 unpadded values. - pad_sum = torch.sum(1 - padding, dim=2) + # Selecting the first patch with more than 3 unpadded values. def _get_patch_index(arr: torch.Tensor): indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) row_sum = (arr >= 3).to(torch.int32).sum(dim=1) return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + pad_sum = torch.sum(1 - padding, dim=2) patch_indices = _get_patch_index(pad_sum) bidxs = torch.arange(inputs.shape[0]) @@ -57,9 +56,8 @@ def _get_patch_index(arr: torch.Tensor): mask = 1 - pad # Calculate the number of valid elements - num_valid_elements = torch.sum(mask, dim=1) num_valid_elements = torch.where( - num_valid_elements == 0, + torch.sum(mask, dim=1) == 0, torch.tensor( 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device ), @@ -87,11 +85,11 @@ def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: """Shifts rows of seq based on the first 0 in each row of the mask. Args: - mask: mask tensor of shape [B, N] - seq: seq tensor of shape [B, N, P] + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] Returns: - Returns the shifted sequence. + The shifted sequence. """ batch_size, num_seq, feature_dim = seq.shape diff --git a/src/transformers/models/timesfm/timesfm.py b/src/transformers/models/timesfm/timesfm.py deleted file mode 100644 index ea27c1e75b8c..000000000000 --- a/src/transformers/models/timesfm/timesfm.py +++ /dev/null @@ -1,202 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch TimesFM model.""" - - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### - -import logging -from os import path -from typing import Any, Sequence - -import numpy as np -import torch -from huggingface_hub import snapshot_download -import timesfm_base -import patched_decoder as ppd -from ...modeling_utils import PreTrainedModel - - -_TOL = 1e-6 - - -class TimesFmTorch(PreTrainedModel, timesfm_base.TimesFmBase): - """TimesFM forecast API for inference.""" - - def __post_init__(self): - self._model_config = ppd.TimesFMConfig( - num_layers=self.num_layers, - num_heads=self.num_heads, - hidden_size=self.model_dims, - intermediate_size=self.model_dims, - patch_len=self.input_patch_len, - horizon_len=self.output_patch_len, - head_dim=self.model_dims // self.num_heads, - quantiles=self.quantiles, - ) - self._model = None - self.num_cores = 1 - self.global_batch_size = self.per_core_batch_size - self._device = torch.device( - "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" - ) - - def load_from_checkpoint( - self, - checkpoint: timesfm_base.TimesFmCheckpoint, - ) -> None: - """Loads a checkpoint and compiles the decoder.""" - checkpoint_path = checkpoint.path - repo_id = checkpoint.huggingface_repo_id - if checkpoint_path is None: - checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt") - self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) - logging.info("Loading checkpoint from %s", checkpoint_path) - self._model.load_state_dict(loaded_checkpoint) - logging.info("Sending checkpoint to device %s", f"{self._device}") - self._model.to(self._device) - self._model.eval() - # TODO: add compilation. - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - if not self._model: - raise ValueError( - "Checkpoint not loaded. Call `load_from_checkpoint` before" - " `forecast`." - ) - if forecast_context_len is None: - fcontext_len = self.context_len - else: - fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) - - if window_size is not None: - new_inputs = [] - for ts in inputs: - new_inputs.extend(timesfm_base.moving_average(ts, window_size)) - inputs = new_inputs - - if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) - - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) - ) - mean_output, full_output = self._model.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] - - if window_size is not None: - mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] - full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs diff --git a/src/transformers/models/timesfm/timesfm_base.py b/src/transformers/models/timesfm/timesfm_base.py deleted file mode 100644 index 7c0c756e6847..000000000000 --- a/src/transformers/models/timesfm/timesfm_base.py +++ /dev/null @@ -1,340 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Base class for TimesFM inference. This will be common to PAX and Pytorch.""" - -import collections -import dataclasses -import logging -import multiprocessing -from typing import Any, Literal, Sequence - -import numpy as np -import pandas as pd - -from utilsforecast.processing import make_future_dataframe -from configuration_timesfm import TimesFMConfig - -_TOL = 1e-6 -DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) - - -def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key - - -def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size - return [smoothed_arr, arr - smoothed_arr] - - -def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if ( - freq.endswith("H") - or freq.endswith("T") - or freq.endswith("MIN") - or freq.endswith("D") - or freq.endswith("B") - or freq.endswith("U") - ): - return 0 - elif freq.endswith(("W", "M", "MS")): - return 1 - elif freq.endswith("Y") or freq.endswith("Q"): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") - - -# Per time series normalization: forward. -def normalize(batch): - stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch] - new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] - return new_batch, stats - - -# Per time series normalization: inverse. -def renormalize(batch, stats): - return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] - - -@dataclasses.dataclass(kw_only=True) -class TimesFmCheckpoint: - """Checkpoint used to initialize a TimesFM model for inference. - - Attributes: - version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. - The factory will create the corresponding TimesFm inference class based on - this version. - path: Path to the checkpoint. - type: If provided, type of the checkpoint used by the specific checkpoint - loader per version. - step: If provided, step of the checkpoint. - """ - - version: str = "jax" - path: str | None = None - huggingface_repo_id: str | None = None - type: Any = None - step: int | None = None - - -class TimesFmBase: - """Base TimesFM forecast API for inference. - - This class is the scaffolding for calling TimesFM forecast. To properly use: - 1. Create an instance with the correct hyperparameters of a TimesFM model. - 2. Call `load_from_checkpoint` to load a compatible checkpoint. - 3. Call `forecast` for inference. - """ - - def _logging(self, s): - print(s) - - def __post_init__(self) -> None: - """Additional initialization for subclasses before checkpoint loading.""" - pass - - def __init__(self, hparams: TimesFMConfig, checkpoint: TimesFmCheckpoint) -> None: - """Initializes the TimesFM forecast API. - - Args: - hparams: Hyperparameters of the model. - checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide - which TimesFM version to use. - """ - self.hparams = hparams - - # Expand hparams for conciseness within the model code. - self.context_len = hparams.context_len - self.horizon_len = hparams.horizon_len - self.input_patch_len = hparams.patch_len - self.output_patch_len = hparams.horizon_len - self.num_layers = hparams.num_layers - self.model_dims = hparams.model_dim - self.backend = hparams.backend - self.quantiles = hparams.quantiles - self.num_heads = hparams.num_heads - - # Rewrite these values in __post_init__ for SPMD. - self.num_cores = 1 - self.per_core_batch_size = hparams.per_core_batch_size - self.global_batch_size = hparams.per_core_batch_size - - self._horizon_start = self.context_len - self.input_patch_len - self.__post_init__() - self.load_from_checkpoint(checkpoint) - - def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: - """Loads a checkpoint and compiles the decoder.""" - raise NotImplementedError("`load_from_checkpoint` is not implemented.") - - def _preprocess( - self, inputs: Sequence[np.array], freq: Sequence[int] - ) -> tuple[np.array, np.array, int]: - """Formats and pads raw inputs to feed into the model. - - This function both pads each time series to match the context length, and - pads the inputs to meet the SPMD shape requirement. - - Args: - inputs: A list of 1d JTensors. Each JTensor is the context time series of - a single forecast task. - freq: list of frequencies - - Returns: - A tuple of: - - the padded input time series to meet the model required context. - - the padding indicator. - - the number of padded examples for SPMD so that each core has the same - number (a multiple of `batch_size`) of examples. - """ - - input_ts, input_padding, inp_freq = [], [], [] - - pmap_pad = ( - (len(inputs) - 1) // self.global_batch_size + 1 - ) * self.global_batch_size - len(inputs) - - for i, ts in enumerate(inputs): - input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) - if input_len < self.context_len: - num_front_pad = self.context_len - input_len - ts = np.concatenate( - [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 - ) - padding = np.concatenate( - [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 - ) - elif input_len > self.context_len: - ts = ts[-self.context_len :] - padding = padding[-(self.context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) - inp_freq.append(freq[i]) - - # Padding the remainder batch. - for _ in range(pmap_pad): - input_ts.append(input_ts[-1]) - input_padding.append(input_padding[-1]) - inp_freq.append(inp_freq[-1]) - - return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), - pmap_pad, - ) - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.array, np.array]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - raise NotImplementedError("`forecast` is not implemented.") - - def forecast_on_df( - self, - inputs: pd.DataFrame, - freq: str, - forecast_context_len: int = 0, - value_name: str = "values", - model_name: str = "timesfm", - window_size: int | None = None, - num_jobs: int = 1, - verbose: bool = True, - ) -> pd.DataFrame: - """Forecasts on a list of time series. - - Args: - inputs: A pd.DataFrame of all time series. The dataframe should have a - `unique_id` column for identifying the time series, a `ds` column for - timestamps and a value column for the time series values. - freq: string valued `freq` of data. Notice this is different from the - `freq` required by `forecast`. See `freq_map` for allowed values. - forecast_context_len: If provided none zero, we take the last - `forecast_context_len` time-points from each series as the forecast - context instead of the `context_len` set by the model. - value_name: The name of the value column. - model_name: name of the model to be written into future df. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - num_jobs: number of parallel processes to use for dataframe processing. - verbose: output model states in terminal. - - Returns: - Future forecasts dataframe. - """ - if not ( - "unique_id" in inputs.columns - and "ds" in inputs.columns - and value_name in inputs.columns - ): - raise ValueError( - f"DataFrame must have unique_id, ds and {value_name} columns." - ) - if not forecast_context_len: - forecast_context_len = self.context_len - logging.info("Preprocessing dataframe.") - df_sorted = inputs.sort_values(by=["unique_id", "ds"]) - new_inputs = [] - uids = [] - if num_jobs == 1: - if verbose: - print("Processing dataframe with single process.") - for key, group in df_sorted.groupby("unique_id"): - inp, uid = process_group( - key, - group, - value_name, - forecast_context_len, - ) - new_inputs.append(inp) - uids.append(uid) - else: - if num_jobs == -1: - num_jobs = multiprocessing.cpu_count() - if verbose: - print("Processing dataframe with multiple processes.") - with multiprocessing.Pool(processes=num_jobs) as pool: - results = pool.starmap( - process_group, - [ - (key, group, value_name, forecast_context_len) - for key, group in df_sorted.groupby("unique_id") - ], - ) - new_inputs, uids = zip(*results) - if verbose: - print("Finished preprocessing dataframe.") - freq_inps = [freq_map(freq)] * len(new_inputs) - _, full_forecast = self.forecast( - new_inputs, freq=freq_inps, window_size=window_size - ) - if verbose: - print("Finished forecasting.") - fcst_df = make_future_dataframe( - uids=uids, - last_times=df_sorted.groupby("unique_id")["ds"].tail(1), - h=self.horizon_len, - freq=freq, - ) - fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) - - for i, q in enumerate(self.quantiles): - q_col = f"{model_name}-q-{q}" - fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( - -1, 1 - ) - if q == 0.5: - fcst_df[model_name] = fcst_df[q_col] - logging.info("Finished creating output dataframe.") - return fcst_df From 04da993a542a1a78f9d7df88a68989ca4536701e Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Tue, 1 Oct 2024 17:15:11 -0700 Subject: [PATCH 014/242] restructing in progress --- src/transformers/models/timesfm/__init__.py | 2 +- .../models/timesfm/modeling_timesfm.py | 146 ------------------ 2 files changed, 1 insertion(+), 147 deletions(-) diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index fe1a08da2678..91f4693ae2e5 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -43,7 +43,7 @@ except OptionalDependencyNotAvailable: pass else: - from .timesfm import ( + from .modeling_timesfm import ( TimesFMModel, TimesFMPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f2df0c061129..64ee4d5f8af6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -330,152 +330,6 @@ def forecast_on_df( return fcst_df -class TimesFMModel(TimesFmBase, nn.Module): - """Body of the TimesFM model, excluding the head.""" - - def __post_init__(self): - self._model_config = TimesFMConfig( - num_layers=self.num_layers, - num_heads=self.num_heads, - hidden_size=self.model_dims, - intermediate_size=self.model_dims, - patch_len=self.input_patch_len, - horizon_len=self.output_patch_len, - head_dim=self.model_dims // self.num_heads, - quantiles=self.quantiles, - ) - - self.num_cores = 1 - self.global_batch_size = self.per_core_batch_size - self._device = torch.device( - "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" - ) - self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - self._model.to(self._device) - self._model.eval() - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - if forecast_context_len is None: - fcontext_len = self.context_len - else: - fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) - - if window_size is not None: - new_inputs = [] - for ts in inputs: - new_inputs.extend(moving_average(ts, window_size)) - inputs = new_inputs - - if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) - - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) - ) - mean_output, full_output = self._model.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] - - if window_size is not None: - mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] - full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs - - class TimesFMModel(TimesFmBase, nn.Module): """TimesFM forecast API for inference.""" From e0bf022094afe8426eead390e215abf75ebc028b Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 2 Oct 2024 17:43:13 -0700 Subject: [PATCH 015/242] adapted to HF convention --- src/transformers/__init__.py | 2 - .../models/timesfm/modeling_timesfm.py | 510 +++++++++--------- .../{patched_decoder.py => timesfm_layers.py} | 254 ++------- 3 files changed, 305 insertions(+), 461 deletions(-) rename src/transformers/models/timesfm/{patched_decoder.py => timesfm_layers.py} (66%) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 428c19f845fb..9abedb65ee69 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3367,7 +3367,6 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMForPrediction", "TimesFMModel", "TimesFMPreTrainedModel", ] @@ -7786,7 +7785,6 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFMForPrediction, TimesFMModel, TimesFMPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 64ee4d5f8af6..adb5a9d31006 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -22,119 +22,290 @@ #################################################### -import dataclasses import logging import multiprocessing from typing import Any, Sequence -from os import path import pandas as pd import numpy as np import torch import torch.nn as nn -from huggingface_hub import snapshot_download from ...modeling_utils import PreTrainedModel +from .configuration_timesfm import TimesFMConfig +from .timesfm_layers import * -import patched_decoder as ppd +# TODO: shall remove this dependency after API design is finalized. from utilsforecast.processing import make_future_dataframe -from configuration_timesfm import TimesFMConfig - - -def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key - - -def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size - return [smoothed_arr, arr - smoothed_arr] - - -def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if ( - freq.endswith("H") - or freq.endswith("T") - or freq.endswith("MIN") - or freq.endswith("D") - or freq.endswith("B") - or freq.endswith("U") - ): - return 0 - elif freq.endswith(("W", "M", "MS")): - return 1 - elif freq.endswith("Y") or freq.endswith("Q"): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") - - -@dataclasses.dataclass(kw_only=True) -class TimesFmCheckpoint: - """Checkpoint used to initialize a TimesFM model for inference. - - Attributes: - version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. - The factory will create the corresponding TimesFm inference class based on - this version. - path: Path to the checkpoint. - type: If provided, type of the checkpoint used by the specific checkpoint - loader per version. - step: If provided, step of the checkpoint. - """ - - version: str = "torch" - path: str | None = None - huggingface_repo_id: str | None = None - type: Any = None - step: int | None = None - - -class TimesFmBase: - """Base TimesFM forecast API for inference. - - This class is the scaffolding for calling TimesFM forecast. To properly use: - 1. Create an instance with the correct hyperparameters of a TimesFM model. - 2. Call `load_from_checkpoint` to load a compatible checkpoint. - 3. Call `forecast` for inference. - """ - - def _logging(self, s): - print(s) - - def __init__(self, hparams: TimesFMConfig) -> None: - """Initializes the TimesFM forecast API. + + +class TimesFMPreTrainedModel(PreTrainedModel): + """handles the loading for all models.""" + + config_class = TimesFMConfig + base_model_prefix = "timesfm" + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): + nn.init.uniform_(module.weight, a=-0.1, b=0.1) + + elif isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + elif isinstance(module, RMSNorm): + nn.init.zeros_(module.weight) + + elif isinstance(module, PositionalEmbedding): + pass + + +class PatchedTimeSeriesDecoder(TimesFMPreTrainedModel): + """Patched time-series decoder.""" + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + + self.config = config + self.input_ff_layer = ResidualBlock( + input_dims=2 * config.patch_len, + output_dims=config.model_dim, + hidden_dims=config.model_dim, + ) + self.freq_emb = nn.Embedding( + num_embeddings=config.freq_size, embedding_dim=config.model_dim + ) + self.horizon_ff_layer = ResidualBlock( + input_dims=config.model_dim, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.model_dim, + ) + self.stacked_transformer = StackedDecoder( + hidden_size=self.config.model_dim, + intermediate_size=self.config.model_dim, + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_heads, + head_dim=self.config.head_dim, + num_layers=self.config.num_layers, + rms_norm_eps=self.config.rms_norm_eps, + ) + if self.config.use_positional_embedding: + self.position_emb = PositionalEmbedding( + embedding_dims=self.config.model_dim, + ) + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=outputs.dtype, device=outputs.device + ), + outputs, + ) + return outputs, (mu, sigma) + + def _reverse_transform( + self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] + + def _preprocess_input( + self, + input_ts: torch.Tensor, + input_padding: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor, torch.Tensor] | None, + torch.Tensor, + ]: + """Preprocess input for stacked transformer.""" + + # Reshape into patches (using view for efficiency) + bsize = input_ts.shape[0] + patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) + patched_pads = input_padding.view(bsize, -1, self.config.patch_len) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, dim=-1)[ + 0 + ] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + return model_input, patched_padding, stats, patched_inputs + + def _postprocess_output( + self, + model_output: torch.Tensor, + num_outputs: int, + stats: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) + + return self._reverse_transform(output_ts, stats) + + def forward( + self, + input_ts: torch.Tensor, + input_padding: torch.LongTensor, + freq: torch.Tensor, + ) -> torch.Tensor: + num_outputs = len(self.config.quantiles) + 1 + model_input, patched_padding, stats, _ = self._preprocess_input( + input_ts=input_ts, + input_padding=input_padding, + ) + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + model_output = self.stacked_transformer(model_input, patched_padding) + + output_ts = self._postprocess_output(model_output, num_outputs, stats) + return output_ts + + def decode( + self, + input_ts: torch.Tensor, + paddings: torch.Tensor, + freq: torch.LongTensor, + horizon_len: int, + output_patch_len: int | None = None, + max_len: int = 512, + return_forecast_on_context: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Auto-regressive decoding without caching. Args: - hparams: Hyperparameters of the model. - checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide - which TimesFM version to use. + input_ts: input time-series and paddings. Time-series shape B x C. + paddings: padding shape B x (C + H) where H is the prediction length. + freq: frequency shape B x 1 + horizon_len: prediction length. + output_patch_len: output length to be fetched from one step of + auto-regressive decoding. + max_len: maximum training context length. + return_forecast_on_context: whether to return the model forecast on the + context except the first input patch. + + Returns: + Tuple of two forecasting results: + - Point (mean) output predictions as a tensor with shape B x H'. + - Full predictions (mean and quantiles) as a tensor with shape + B x H' x (1 + # quantiles). + In particular, if return_forecast_on_context is True, H' is H plus + the forecastable context length, i.e. context_len - (first) patch_len. """ - self.hparams = hparams - - # Expand hparams for conciseness within the model code. - self.context_len = hparams.context_len - self.horizon_len = hparams.horizon_len - self.input_patch_len = hparams.patch_len - self.output_patch_len = hparams.horizon_len - self.num_layers = hparams.num_layers - self.model_dims = hparams.model_dim - self.backend = hparams.backend - self.quantiles = hparams.quantiles - self.num_heads = hparams.num_heads - - # Rewrite these values in subclasses for SPMD. + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" + ) + if output_patch_len is None: + output_patch_len = self.config.horizon_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = paddings[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -max_len:] + input_padding = current_padding[:, -max_len:] + fprop_outputs = self(input_ts, input_padding, freq) + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] + new_full_ts = fprop_outputs.view( + new_full_ts.size(0), -1, new_full_ts.size(3) + ) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concatenate([final_out, new_ts], axis=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concatenate(full_outputs, axis=1)[ + :, : (context_len - self.config.patch_len + horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] + + return (full_outputs[:, :, 0], full_outputs) + + +class TimesFMModel(TimesFMPreTrainedModel): + def __init__(self, config: TimesFMConfig): + super().__init__(config) + + self.config = config + + self.decoder = PatchedTimeSeriesDecoder(config) + + self.context_len = config.context_len + self.horizon_len = config.horizon_len + self.input_patch_len = config.patch_len + self.output_patch_len = config.horizon_len + self.num_layers = config.num_layers + self.model_dims = config.model_dim + self.backend = config.backend + self.quantiles = config.quantiles + self.num_heads = config.num_heads + self.num_cores = 1 - self.per_core_batch_size = hparams.per_core_batch_size - self.global_batch_size = hparams.per_core_batch_size + self.per_core_batch_size = config.per_core_batch_size + self.global_batch_size = config.per_core_batch_size * self.num_cores self._horizon_start = self.context_len - self.input_patch_len - - def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: - """Loads a checkpoint and compiles the decoder.""" - raise NotImplementedError("`load_from_checkpoint` is not implemented.") + self._device = config.backend def _preprocess( self, inputs: Sequence[np.array], freq: Sequence[int] @@ -329,152 +500,9 @@ def forecast_on_df( logging.info("Finished creating output dataframe.") return fcst_df - -class TimesFMModel(TimesFmBase, nn.Module): - """TimesFM forecast API for inference.""" - - def __init__(self, hparams: TimesFMConfig) -> None: - super.__init__(hparams) - self._model_config = hparams - self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - self.num_cores = 1 - self.global_batch_size = self.per_core_batch_size - self._device = torch.device( - "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" - ) - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - if not self._model: - raise ValueError( - "Checkpoint not loaded. Call `load_from_checkpoint` before" - " `forecast`." - ) - if forecast_context_len is None: - fcontext_len = self.context_len - else: - fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) - - if window_size is not None: - new_inputs = [] - for ts in inputs: - new_inputs.extend(moving_average(ts, window_size)) - inputs = new_inputs - - if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) - - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) - ) - mean_output, full_output = self._model.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] - - if window_size is not None: - mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] - full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs - def forward(self, x, **kwargs): if isinstance(x, pd.DataFrame): assert "freq" in kwargs, "Frequency must be provided for DataFrame input." return self.forecast_on_df(x, **kwargs) else: return self.forecast(x, **kwargs) - - -## TODO: Define the PreTrainedTimesFMModel class diff --git a/src/transformers/models/timesfm/patched_decoder.py b/src/transformers/models/timesfm/timesfm_layers.py similarity index 66% rename from src/transformers/models/timesfm/patched_decoder.py rename to src/transformers/models/timesfm/timesfm_layers.py index baafe6be148c..713623eb98a7 100644 --- a/src/transformers/models/timesfm/patched_decoder.py +++ b/src/transformers/models/timesfm/timesfm_layers.py @@ -17,13 +17,13 @@ import math from typing import List, Tuple +import numpy as np import torch from torch import nn import torch.nn.functional as F -from transformers.models.timesfm.configuration_timesfm import TimesFMConfig -def _masked_mean_std( +def masked_mean_std( inputs: torch.Tensor, padding: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates mean and standard deviation of `inputs` across axis 1. @@ -81,7 +81,7 @@ def _get_patch_index(arr: torch.Tensor): return masked_mean, masked_std -def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: +def shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: """Shifts rows of seq based on the first 0 in each row of the mask. Args: @@ -213,6 +213,39 @@ def expand_t(key_mask): return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if ( + freq.endswith("H") + or freq.endswith("T") + or freq.endswith("MIN") + or freq.endswith("D") + or freq.endswith("B") + or freq.endswith("U") + ): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + class ResidualBlock(nn.Module): """TimesFM residual block.""" @@ -547,218 +580,3 @@ def forward(self, seq_length=None, position=None): # Padding to ensure correct embedding dimension signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) return signal - - -class PatchedTimeSeriesDecoder(nn.Module): - """Patched time-series decoder.""" - - def __init__(self, config: TimesFMConfig): - super().__init__() - self.config = config - self.input_ff_layer = ResidualBlock( - input_dims=2 * config.patch_len, - output_dims=config.model_dim, - hidden_dims=config.model_dim, - ) - self.freq_emb = nn.Embedding(num_embeddings=3, embedding_dim=config.model_dim) - self.horizon_ff_layer = ResidualBlock( - input_dims=config.model_dim, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.model_dim, - ) - self.stacked_transformer = StackedDecoder( - hidden_size=self.config.model_dim, - intermediate_size=self.config.model_dim, - num_heads=self.config.num_heads, - num_kv_heads=self.config.num_heads, - head_dim=self.config.head_dim, - num_layers=self.config.num_layers, - rms_norm_eps=self.config.rms_norm_eps, - ) - if self.config.use_positional_embedding: - self.position_emb = PositionalEmbedding(self.config.model_dim) - - def _forward_transform( - self, inputs: torch.Tensor, patched_pads: torch.Tensor - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """Input is of shape [B, N, P].""" - mu, sigma = _masked_mean_std(inputs, patched_pads) - sigma = torch.where( - sigma < self.config.tolerance, - torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), - sigma, - ) - - # Normalize each patch - outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] - outputs = torch.where( - torch.abs(inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor( - self.config.pad_val, dtype=outputs.dtype, device=outputs.device - ), - outputs, - ) - return outputs, (mu, sigma) - - def _reverse_transform( - self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] - ) -> torch.Tensor: - """Output is of shape [B, N, P, Q].""" - mu, sigma = stats - return outputs * sigma[:, None, None, None] + mu[:, None, None, None] - - def _preprocess_input( - self, - input_ts: torch.Tensor, - input_padding: torch.Tensor, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - tuple[torch.Tensor, torch.Tensor] | None, - torch.Tensor, - ]: - """Preprocess input for stacked transformer.""" - - # Reshape into patches (using view for efficiency) - bsize = input_ts.shape[0] - patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) - patched_pads = input_padding.view(bsize, -1, self.config.patch_len) - - patched_inputs = torch.where( - torch.abs(patched_pads - 1.0) < self.config.tolerance, - torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), - patched_inputs, - ) - patched_pads = torch.where( - torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), - patched_pads, - ) - patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) - - # B x N x D - patched_inputs = patched_inputs * (1.0 - patched_pads) - concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) - model_input = self.input_ff_layer(concat_inputs) - - # A patch should not be padded even if there is at least one zero. - patched_padding = torch.min(patched_pads, dim=-1)[ - 0 - ] # Get the values from the min result - if self.config.use_positional_embedding: - pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) - pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) - pos_emb = _shift_padded_seq(patched_padding, pos_emb) - model_input += pos_emb - - return model_input, patched_padding, stats, patched_inputs - - def _postprocess_output( - self, - model_output: torch.Tensor, - num_outputs: int, - stats: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - """Postprocess output of stacked transformer.""" - - # B x N x (H.Q) - output_ts = self.horizon_ff_layer(model_output) - - # Reshape using view - b, n, _ = output_ts.shape - output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) - - return self._reverse_transform(output_ts, stats) - - def forward( - self, - input_ts: torch.Tensor, - input_padding: torch.LongTensor, - freq: torch.Tensor, - ) -> torch.Tensor: - num_outputs = len(self.config.quantiles) + 1 - model_input, patched_padding, stats, _ = self._preprocess_input( - input_ts=input_ts, - input_padding=input_padding, - ) - f_emb = self.freq_emb(freq) # B x 1 x D - model_input += f_emb - model_output = self.stacked_transformer(model_input, patched_padding) - - output_ts = self._postprocess_output(model_output, num_outputs, stats) - return output_ts - - def decode( - self, - input_ts: torch.Tensor, - paddings: torch.Tensor, - freq: torch.LongTensor, - horizon_len: int, - output_patch_len: int | None = None, - max_len: int = 512, - return_forecast_on_context: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Auto-regressive decoding without caching. - - Args: - input_ts: input time-series and paddings. Time-series shape B x C. - paddings: padding shape B x (C + H) where H is the prediction length. - freq: frequency shape B x 1 - horizon_len: prediction length. - output_patch_len: output length to be fetched from one step of - auto-regressive decoding. - max_len: maximum training context length. - return_forecast_on_context: whether to return the model forecast on the - context except the first input patch. - - Returns: - Tuple of two forecasting results: - - Point (mean) output predictions as a tensor with shape B x H'. - - Full predictions (mean and quantiles) as a tensor with shape - B x H' x (1 + # quantiles). - In particular, if return_forecast_on_context is True, H' is H plus - the forecastable context length, i.e. context_len - (first) patch_len. - """ - final_out = input_ts - context_len = final_out.shape[1] - full_outputs = [] - if paddings.shape[1] != final_out.shape[1] + horizon_len: - raise ValueError( - "Length of paddings must match length of input + horizon_len:" - f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" - ) - if output_patch_len is None: - output_patch_len = self.config.horizon_len - num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len - for step_index in range(num_decode_patches): - current_padding = paddings[:, 0 : final_out.shape[1]] - input_ts = final_out[:, -max_len:] - input_padding = current_padding[:, -max_len:] - fprop_outputs = self(input_ts, input_padding, freq) - if return_forecast_on_context and step_index == 0: - # For the first decodings step, collect the model forecast on the - # context except the unavailable first input batch forecast. - new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - new_full_ts = fprop_outputs.view( - new_full_ts.size(0), -1, new_full_ts.size(3) - ) - - full_outputs.append(new_full_ts) - - # (full batch, last patch, output_patch_len, index of mean forecast = 0) - new_ts = fprop_outputs[:, -1, :output_patch_len, 0] - new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] - # (full batch, last patch, output_patch_len, all output indices) - full_outputs.append(new_full_ts) - final_out = torch.concatenate([final_out, new_ts], axis=-1) - - if return_forecast_on_context: - # `full_outputs` indexing starts at after the first input patch. - full_outputs = torch.concatenate(full_outputs, axis=1)[ - :, : (context_len - self.config.patch_len + horizon_len), : - ] - else: - # `full_outputs` indexing starts at the forecast horizon. - full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - - return (full_outputs[:, :, 0], full_outputs) From 729a57631fad63e4a07d8fb87f16ca1704aca9d5 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 9 Oct 2024 17:48:35 -0700 Subject: [PATCH 016/242] timesfm test --- .../models/timesfm/modeling_timesfm.py | 195 ++- tests/models/timesfm/test_modeling_timesfm.py | 1347 +---------------- 2 files changed, 170 insertions(+), 1372 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index adb5a9d31006..612493e7f7dd 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -23,7 +23,6 @@ import logging -import multiprocessing from typing import Any, Sequence import pandas as pd import numpy as np @@ -33,9 +32,6 @@ from .configuration_timesfm import TimesFMConfig from .timesfm_layers import * -# TODO: shall remove this dependency after API design is finalized. -from utilsforecast.processing import make_future_dataframe - class TimesFMPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" @@ -366,7 +362,7 @@ def _preprocess( pmap_pad, ) - def forecast( + def forward( self, inputs: Sequence[Any], freq: Sequence[int] | None = None, @@ -374,7 +370,7 @@ def forecast( forecast_context_len: int | None = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, - ) -> tuple[np.array, np.array]: + ) -> tuple[np.ndarray, np.ndarray]: """Forecasts on a list of time series. Args: @@ -400,109 +396,90 @@ def forecast( Raises: ValueError: If the checkpoint is not properly loaded. """ - raise NotImplementedError("`forecast` is not implemented.") - - def forecast_on_df( - self, - inputs: pd.DataFrame, - freq: str, - forecast_context_len: int = 0, - value_name: str = "values", - model_name: str = "timesfm", - window_size: int | None = None, - num_jobs: int = 1, - verbose: bool = True, - ) -> pd.DataFrame: - """Forecasts on a list of time series. - Args: - inputs: A pd.DataFrame of all time series. The dataframe should have a - `unique_id` column for identifying the time series, a `ds` column for - timestamps and a value column for the time series values. - freq: string valued `freq` of data. Notice this is different from the - `freq` required by `forecast`. See `freq_map` for allowed values. - forecast_context_len: If provided none zero, we take the last - `forecast_context_len` time-points from each series as the forecast - context instead of the `context_len` set by the model. - value_name: The name of the value column. - model_name: name of the model to be written into future df. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - num_jobs: number of parallel processes to use for dataframe processing. - verbose: output model states in terminal. - - Returns: - Future forecasts dataframe. - """ - if not ( - "unique_id" in inputs.columns - and "ds" in inputs.columns - and value_name in inputs.columns - ): - raise ValueError( - f"DataFrame must have unique_id, ds and {value_name} columns." - ) - if not forecast_context_len: - forecast_context_len = self.context_len - logging.info("Preprocessing dataframe.") - df_sorted = inputs.sort_values(by=["unique_id", "ds"]) - new_inputs = [] - uids = [] - if num_jobs == 1: - if verbose: - print("Processing dataframe with single process.") - for key, group in df_sorted.groupby("unique_id"): - inp, uid = process_group( - key, - group, - value_name, - forecast_context_len, - ) - new_inputs.append(inp) - uids.append(uid) + if forecast_context_len is None: + fcontext_len = self.context_len else: - if num_jobs == -1: - num_jobs = multiprocessing.cpu_count() - if verbose: - print("Processing dataframe with multiple processes.") - with multiprocessing.Pool(processes=num_jobs) as pool: - results = pool.starmap( - process_group, - [ - (key, group, value_name, forecast_context_len) - for key, group in df_sorted.groupby("unique_id") - ], + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) ) - new_inputs, uids = zip(*results) - if verbose: - print("Finished preprocessing dataframe.") - freq_inps = [freq_map(freq)] * len(new_inputs) - _, full_forecast = self.forecast( - new_inputs, freq=freq_inps, window_size=window_size - ) - if verbose: - print("Finished forecasting.") - fcst_df = make_future_dataframe( - uids=uids, - last_times=df_sorted.groupby("unique_id")["ds"].tail(1), - h=self.horizon_len, - freq=freq, - ) - fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) - - for i, q in enumerate(self.quantiles): - q_col = f"{model_name}-q-{q}" - fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( - -1, 1 - ) - if q == 0.5: - fcst_df[model_name] = fcst_df[q_col] - logging.info("Finished creating output dataframe.") - return fcst_df - - def forward(self, x, **kwargs): - if isinstance(x, pd.DataFrame): - assert "freq" in kwargs, "Frequency must be provided for DataFrame input." - return self.forecast_on_df(x, **kwargs) - else: - return self.forecast(x, **kwargs) + mean_output, full_output = self.decoder.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index e08277fac50f..2aafe4133c8d 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -14,11 +14,9 @@ # limitations under the License. -import copy -import os -import pickle -import tempfile +import numpy as np import unittest +from typing import List from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( @@ -33,7 +31,7 @@ from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor +from ...test_modeling_common import ModelTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin @@ -46,10 +44,7 @@ from transformers import ( AutoTokenizer, - ByT5Tokenizer, - TimesFMForPrediction, TimesFMModel, - T5Tokenizer, ) @@ -57,1295 +52,121 @@ class TimesFMModelTester: def __init__( self, parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - decoder_seq_length=7, - # For common tests - is_training=True, - use_attention_mask=True, - use_labels=True, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - dropout_rate=0.1, - initializer_factor=0.002, - eos_token_id=1, - pad_token_id=0, - decoder_start_token_id=0, - scope=None, - decoder_layers=None, + patch_len: int = 32, + context_len: int = 512, + horizon_len: int = 128, + freq_size: int = 3, + num_layers: int = 20, + model_dim: int = 1280, + head_dim: int = 80, + num_heads: int = 16, + dropout_rate: float = 0.1, + tolerance: float = 1e-6, + rms_norm_eps: float = 1e-6, + quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + pad_val: float = 1123581321.0, + use_positional_embedding: bool = True, + per_core_batch_size: int = 32, + initializer_factor: float = 1.0, + backend: str = "gpu", ): self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - self.decoder_seq_length = decoder_seq_length - # For common tests - self.seq_length = self.decoder_seq_length - self.is_training = is_training - self.use_attention_mask = use_attention_mask - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets + self.patch_len = patch_len + self.context_len = context_len + self.horizon_len = horizon_len + self.quantiles = quantiles + self.pad_val = pad_val + self.freq_size = freq_size + self.model_dim = model_dim + self.head_dim = head_dim + self.num_layers = num_layers + self.num_heads = num_heads self.dropout_rate = dropout_rate + self.tolerance = tolerance + self.rms_norm_eps = rms_norm_eps + self.use_positional_embedding = use_positional_embedding + self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.decoder_start_token_id = decoder_start_token_id - self.scope = None - self.decoder_layers = decoder_layers + self.backend = backend def get_large_model_config(self): - return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2) - input_ids[:, -1] = self.eos_token_id # Eos Token - decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) - - attention_mask = None - decoder_attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) - - lm_labels = None - if self.use_labels: - lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) - - config = self.get_config() - - return ( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) - - def get_pipeline_config(self): - return TimesFMConfig( - vocab_size=166, # timesfm forces 100 extra tokens - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_decoder_layers=self.decoder_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.decoder_start_token_id, - ) + return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") def get_config(self): return TimesFMConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_decoder_layers=self.decoder_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, + patch_len=self.patch_len, + context_len=self.context_len, + horizon_len=self.horizon_len, + quantiles=self.quantiles, + pad_val=self.pad_val, + freq_size=self.freq_size, + model_dim=self.model_dim, + head_dim=self.head_dim, + num_layers=self.num_layers, + num_heads=self.num_heads, dropout_rate=self.dropout_rate, + tolerance=self.tolerance, + rms_norm_eps=self.rms_norm_eps, + use_positional_embedding=self.use_positional_embedding, + per_core_batch_size=self.per_core_batch_size, initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.decoder_start_token_id, + backend=self.backend, ) - def check_prepare_lm_labels_via_shift_left( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config) - model.to(torch_device) - model.eval() - - # make sure that lm_labels are correctly padded from the right - lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) - - # add casaul pad token mask - triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() - lm_labels.masked_fill_(triangular_mask, self.pad_token_id) - decoder_input_ids = model._shift_right(lm_labels) - - for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): - # first item - self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) - if i < decoder_input_ids_slice.shape[-1]: - if i < decoder_input_ids.shape[-1] - 1: - # items before diagonal - self.parent.assertListEqual( - decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() - ) - # pad items after diagonal - if i < decoder_input_ids.shape[-1] - 2: - self.parent.assertListEqual( - decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() - ) - else: - # all items after square - self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) - - def create_and_check_model( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config) - model.to(torch_device) - model.eval() - result = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - decoder_output = result.last_hidden_state - decoder_past = result.past_key_values - encoder_output = result.encoder_last_hidden_state - - self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) - self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) - # There should be `num_layers` key value embeddings stored in decoder_past - self.parent.assertEqual(len(decoder_past), config.num_layers) - # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple - self.parent.assertEqual(len(decoder_past[0]), 4) - - def create_and_check_with_lm_head( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMForPrediction(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - labels=lm_labels, - ) - self.parent.assertEqual(len(outputs), 4) - self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) - self.parent.assertEqual(outputs["loss"].size(), ()) - - def create_and_check_decoder_model_past( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() - # first forward pass - outputs = model(input_ids, use_cache=True) - outputs_use_cache_conf = model(input_ids) - outputs_no_past = model(input_ids, use_cache=False) - - self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) - self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) - - output, past_key_values = outputs.to_tuple() - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) - - # append to next input_ids and - next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - - output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] - - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() - output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() - - # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - - def create_and_check_decoder_model_attention_mask_past( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).get_decoder() - model.to(torch_device) - model.eval() - - # create attention mask - attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) - - half_seq_length = input_ids.shape[-1] // 2 - attn_mask[:, half_seq_length:] = 0 - - # first forward pass - output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) - - # change a random masked slice from input_ids - random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 - random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) - input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens - - # append to next input_ids and attn_mask - next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - attn_mask = torch.cat( - [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], - dim=1, - ) - - # get two different outputs - output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ - "last_hidden_state" - ] - - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() - output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() - - # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - - def create_and_check_decoder_model_past_large_inputs( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() - # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) - - output, past_key_values = outputs.to_tuple() - - # create hypothetical multiple next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) - next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) - - # append to next input_ids and - next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + def get_pipeline_config(self): + return self.get_config() - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ - "last_hidden_state" + def prepare_config_and_inputs(self): + forecast_input = [ + np.sin(np.linspace(0, 20, 100)), + np.sin(np.linspace(0, 20, 200)), + np.sin(np.linspace(0, 20, 400)), ] + frequency_input = [0, 1, 2] - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() - output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() - - self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) - - # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + config = self.get_config() - def create_and_check_generate_with_past_key_values( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMForPrediction(config=config).to(torch_device).eval() - torch.manual_seed(0) - output_without_past_cache = model.generate( - input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False + return ( + config, + forecast_input, + frequency_input, ) - torch.manual_seed(0) - output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) - self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) - - def create_and_check_model_fp16_forward( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).to(torch_device).half().eval() - output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def create_and_check_encoder_decoder_shared_weights( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - for model_class in [TimesFMModel, TimesFMForPrediction]: - torch.manual_seed(0) - model = model_class(config=config).to(torch_device).eval() - # load state dict copies weights but does not tie them - model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) - - torch.manual_seed(0) - tied_config = copy.deepcopy(config) - tied_config.tie_encoder_decoder = True - tied_model = model_class(config=tied_config).to(torch_device).eval() - - model_result = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - tied_model_result = tied_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) - random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() - - # check that outputs are equal - self.parent.assertTrue( - torch.allclose( - model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 - ) - ) - - # check that outputs after saving and loading are equal - with tempfile.TemporaryDirectory() as tmpdirname: - tied_model.save_pretrained(tmpdirname) - tied_model = model_class.from_pretrained(tmpdirname) - tied_model.to(torch_device) - tied_model.eval() - - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) - random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() - - tied_model_result = tied_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - # check that outputs are equal - self.parent.assertTrue( - torch.allclose( - model_result[0][0, :, random_slice_idx], - tied_model_result[0][0, :, random_slice_idx], - atol=1e-4, - ) - ) - - def check_resize_embeddings_timesfm_v1_1( - self, - config, - ): - prev_vocab_size = config.vocab_size - - config.tie_word_embeddings = False - model = TimesFMForPrediction(config=config).to(torch_device).eval() - model.resize_token_embeddings(prev_vocab_size - 10) - - self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) - self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) - self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() ( config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) = config_and_inputs + forecast_input, + frequency_input, + ) = self.prepare_config_and_inputs() inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": decoder_attention_mask, - "use_cache": False, + "inputs": forecast_input, + "freq": frequency_input, } return config, inputs_dict @require_torch -class TimesFMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = ( - (TimesFMModel, TimesFMForPrediction) - if is_torch_available() - else () - ) - all_generative_model_classes = (TimesFMForPrediction,) if is_torch_available() else () - all_parallelizable_model_classes = (TimesFMModel, TimesFMForPrediction) if is_torch_available() else () +class TimesFMModelTest( + ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase +): + all_model_classes = (TimesFMModel,) if is_torch_available() else () + all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () + all_parallelizable_model_classes = () fx_compatible = False test_pruning = False - test_resize_embeddings = True - test_model_parallel = True - is_encoder_decoder = True - # The small TimesFM model needs higher percentages for CPU/MP tests - model_split_percents = [0.5, 0.8, 0.9] + test_resize_embeddings = False + test_model_parallel = False + is_encoder_decoder = False def setUp(self): self.model_tester = TimesFMModelTester(self) - self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) + self.config_tester = ConfigTester(self, config_class=TimesFMConfig) - # TimesFMForSequenceClassification does not support inputs_embeds - def test_inputs_embeds(self): + def test_create_and_run_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in (TimesFMModel, TimesFMForPrediction): - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - - if not self.is_encoder_decoder: - input_ids = inputs["input_ids"] - del inputs["input_ids"] - else: - encoder_input_ids = inputs["input_ids"] - decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) - del inputs["input_ids"] - inputs.pop("decoder_input_ids", None) - - wte = model.get_input_embeddings() - if not self.is_encoder_decoder: - inputs["inputs_embeds"] = wte(input_ids) - else: - inputs["inputs_embeds"] = wte(encoder_input_ids) - inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) - - with torch.no_grad(): - model(**inputs)[0] - - def test_config_and_model_silu_gated(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - config = config_and_inputs[0] - config.feed_forward_proj = "gated-silu" - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_with_lm_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_with_lm_head(*config_and_inputs) - - def test_decoder_model_past(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) - - def test_decoder_model_past_with_attn_mask(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) - - def test_decoder_model_past_with_3d_attn_mask(self): - ( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) = self.model_tester.prepare_config_and_inputs() - - attention_mask = ids_tensor( - [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], - vocab_size=2, - ) - decoder_attention_mask = ids_tensor( - [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], - vocab_size=2, - ) - - self.model_tester.create_and_check_decoder_model_attention_mask_past( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) - - def test_decoder_model_past_with_large_inputs(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - - def test_generate_with_past_key_values(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) - - def test_encoder_decoder_shared_weights(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) - - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_model_fp16_forward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) - - def test_v1_1_resize_embeddings(self): - config = self.model_tester.prepare_config_and_inputs()[0] - self.model_tester.check_resize_embeddings_timesfm_v1_1(config) - - @slow - def test_model_from_pretrained(self): - model_name = "google/timesfm-1.0-200m" - model = TimesFMModel.from_pretrained(model_name) - self.assertIsNotNone(model) - - @unittest.skip(reason="Test has a segmentation fault on torch 1.8.0") - def test_export_to_onnx(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - model = TimesFMModel(config_and_inputs[0]).to(torch_device) - with tempfile.TemporaryDirectory() as tmpdirname: - torch.onnx.export( - model, - (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), - f"{tmpdirname}/timesfm_test.onnx", - export_params=True, - opset_version=9, - input_names=["input_ids", "decoder_input_ids"], - ) - - def test_generate_with_head_masking(self): - attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - config = config_and_inputs[0] - max_length = config_and_inputs[1].shape[-1] + 3 - model = TimesFMForPrediction(config).eval() + model = TimesFMModel(config) model.to(torch_device) - - head_masking = { - "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device), - "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), - "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), - } - - for attn_name, (name, mask) in zip(attention_names, head_masking.items()): - head_masks = {name: mask} - # Explicitly pass decoder_head_mask as it is required from TimesFM model when head_mask specified - if name == "head_mask": - head_masks["decoder_head_mask"] = torch.ones( - config.num_decoder_layers, config.num_heads, device=torch_device - ) - - out = model.generate( - config_and_inputs[1], - num_beams=1, - max_length=max_length, - output_attentions=True, - return_dict_in_generate=True, - **head_masks, - ) - # We check the state of decoder_attentions and cross_attentions just from the last step - attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] - self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) - - -class TimesFMEncoderOnlyModelTester: - def __init__( - self, - parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - # For common tests - use_attention_mask=True, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - is_training=False, - dropout_rate=0.1, - initializer_factor=0.002, - is_encoder_decoder=False, - eos_token_id=1, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - # For common tests - self.seq_length = self.encoder_seq_length - self.use_attention_mask = use_attention_mask - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets - self.dropout_rate = dropout_rate - self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.is_encoder_decoder = is_encoder_decoder - self.scope = None - self.is_training = is_training - - def get_large_model_config(self): - return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) - - attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - - config = TimesFMConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - return ( - config, - input_ids, - attention_mask, - ) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - attention_mask, - ) = config_and_inputs - - inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - return config, inputs_dict - - -def use_task_specific_params(model, task): - model.config.update(model.config.task_specific_params[task]) - - -@require_torch -@require_accelerate -@require_tokenizers -@slow -class TimesFMModelFp16Tests(unittest.TestCase): - def test_fp16_fp32_conversion(self): - r""" - A test to check whether the argument `keep_in_fp32_modules` correctly does its job - """ - orig_import = __import__ - accelerate_mock = unittest.mock.Mock() - - # mock import of accelerate - def import_accelerate_mock(name, *args, **kwargs): - if name == "accelerate": - if accelerate_available: - return accelerate_mock - else: - raise ImportError - return orig_import(name, *args, **kwargs) - - # Load without using `accelerate` - with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): - accelerate_available = False - - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) - - # Load without in bf16 - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) - - # Load using `accelerate` in bf16 - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, device_map="auto" - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) - - # Load using `accelerate` in bf16 - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) - - # Load without using `accelerate` - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.float16, low_cpu_mem_usage=True - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) - - # Load using `accelerate` - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.float16, device_map="auto" - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) - - -@require_torch -@require_sentencepiece -@require_tokenizers -class TimesFMModelIntegrationTests(unittest.TestCase): - @cached_property - def model(self): - return TimesFMForPrediction.from_pretrained("google-timesfm/timesfm-base").to(torch_device) - - @cached_property - def tokenizer(self): - return T5Tokenizer.from_pretrained("google-timesfm/timesfm-base") - - @slow - def test_torch_quant(self): - r""" - Test that a simple `torch.quantization.quantize_dynamic` call works on a TimesFM model. - """ - model_name = "google/flan-timesfm-small" - tokenizer = T5Tokenizer.from_pretrained(model_name) - model = TimesFMForPrediction.from_pretrained(model_name) - model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) - input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" - input_ids = tokenizer(input_text, return_tensors="pt").input_ids - _ = model.generate(input_ids) - - @slow - def test_small_generation(self): - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) - model.config.max_length = 8 - model.config.num_beams = 1 - model.config.do_sample = False - tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") - - input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device) - - sequences = model.generate(input_ids) - - output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] - self.assertTrue(output_str == "Hello there!") - - @slow - def test_small_integration_test(self): - """ - For comparision run: - >>> import timesfm # pip install timesfm==0.7.1 - >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_timesfm_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) - tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -19.0845 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_v1_1_integration_test(self): - """ - For comparision run: - >>> import timesfm # pip install timesfm==0.7.1 - >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_timesfm_v1_1_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_v1_1_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TimesFMForPrediction.from_pretrained("google/timesfm-v1_1-small").to(torch_device) - tokenizer = T5Tokenizer.from_pretrained("google/timesfm-v1_1-small") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -59.0293 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_bytimesfm_integration_test(self): - """ - For comparision run: - >>> import timesfm # pip install timesfm==0.9.1 - - >>> path_to_bytimesfm_small_checkpoint = '' - >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) - >>> vocab = timesfm.data.ByteVocabulary() - >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TimesFMForPrediction.from_pretrained("google/bytimesfm-small").to(torch_device) - tokenizer = ByT5Tokenizer.from_pretrained("google/bytimesfm-small") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -60.7397 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_summarization(self): - model = self.model - tok = self.tokenizer - - FRANCE_ARTICLE = ( # @noqa - "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" - " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." - ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' - ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' - " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" - " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" - " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" - " phone at the wreckage site. The two publications described the supposed video, but did not post it on" - " their websites. The publications said that they watched the video, which was found by a source close to" - " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." - ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' - " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" - ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' - " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" - " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" - " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" - ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' - ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' - " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" - " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" - " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" - ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' - ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' - ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' - ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' - " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" - ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' - " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" - " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" - ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' - ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' - " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" - " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" - " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" - " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" - ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' - " sharing the information and documents -- including training and medical records -- with public" - " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" - " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" - " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" - " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" - " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." - " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" - " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." - " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." - " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" - " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" - " the flight school during his training were among several developments as investigators continued to" - " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" - " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" - ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' - " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" - " some point before his aviation career and underwent psychotherapy before he got his pilot's license." - " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" - " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" - " lose his pilot's license, a European government official briefed on the investigation told CNN on" - ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' - " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" - " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" - " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" - " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" - " he had psychological issues, the European government official said. But no matter what details emerge" - " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" - ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' - " that maybe they weren't going to keep doing their job and they're upset about that and so they're" - ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' - " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" - ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' - " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" - " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" - " Amiel and Anna-Maja Rappard contributed to this report." - ) - SHORTER_ARTICLE = ( - "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" - " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" - " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." - " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" - ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' - ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' - " situation in Palestinian territories, paving the way for possible war crimes investigations against" - " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" - " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" - " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" - ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' - ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' - ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' - " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" - ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' - " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." - ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' - ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' - " immediately end their pressure, and countries that support universal acceptance of the court's treaty" - ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' - " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" - ' decision to join a treaty to which over 100 countries around the world are members." In January, when' - " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" - ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' - " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" - ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' - ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' - ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' - " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" - ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' - " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" - ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' - " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" - " will include alleged war crimes committed since June. The International Criminal Court was set up in" - " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" - " and Faith Karimi contributed to this report." - ) - IRAN_ARTICLE = ( - "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" - " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" - " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." - " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" - " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" - " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" - " the announcement of the new framework will likely result in more heat than light. It will not be helped" - " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." - " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" - " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" - " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" - " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" - " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" - " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" - " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" - " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" - " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" - " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" - " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" - " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" - " point, and we'll know even more about Iran's program in the coming months and years because of the deal." - " In fact, the inspections provisions that are part of this agreement are designed to protect against any" - " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" - " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" - " warning that a deal might be killed by Congress or a future president). This of course is not the case." - " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," - " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" - " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" - " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" - " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" - " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" - " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" - " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" - " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" - " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" - " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" - " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" - ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' - " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" - " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" - " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" - " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" - " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" - " some insist that any agreement must address Iranian missile programs, human rights violations or support" - " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" - " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" - " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" - " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" - " fact-based, not based on questionable assertions or dubious assumptions." - ) - ARTICLE_SUBWAY = ( - "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - - expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' - " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" - " magazine says .", - "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" - " preliminary examination into the situation in the occupied Palestinian territory . as members of the" - " court, Palestinians may be subject to counter-charges as well .", - "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" - " the debate that has already begun since the announcement of the new framework will likely result in more" - " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" - " implement a rigorous inspection regime .", - "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" - ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' - " times, with nine of her marriages occurring between 1999 and 2002 .", - ] - - use_task_specific_params(model, "summarization") - - dct = tok( - [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], - padding="max_length", - truncation=True, - return_tensors="pt", - ).to(torch_device) - self.assertEqual(512, dct["input_ids"].shape[1]) - - hypotheses_batch = model.generate( - **dct, - num_beams=4, - length_penalty=2.0, - max_length=142, - min_length=56, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - - decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertListEqual( - expected_summaries, - decoded, - ) - - @slow - def test_translation_en_to_de(self): - model = self.model - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_de") - - en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' - expected_translation = ( - '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' - ) - - input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") - input_ids = input_ids.to(torch_device) - output = model.generate(input_ids) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(translation, expected_translation) - - @slow - def test_translation_en_to_fr(self): - model = self.model # google-timesfm/timesfm-base - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_fr") - - en_text = ( - ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' - " countless generations of stars: the oldest stars are seen as blue dots. " - ) - - input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") - input_ids = input_ids.to(torch_device) - - output = model.generate( - input_ids=input_ids, - num_beams=4, - length_penalty=2.0, - max_length=100, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - new_truncated_translation = ( - "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " - "un " - "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " - "sous forme " - "de points bleus." - ) - - self.assertEqual(translation, new_truncated_translation) - - @slow - def test_translation_en_to_ro(self): - model = self.model - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_ro") - en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022." - expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." - - inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device) - output = model.generate(**inputs) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(translation, expected_translation) - - @slow - def test_contrastive_search_timesfm(self): - article = ( - " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - article = "summarize: " + article.strip() - timesfm_tokenizer = AutoTokenizer.from_pretrained("flax-community/timesfm-base-cnn-dm") - timesfm_model = TimesFMForPrediction.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) - input_ids = timesfm_tokenizer( - article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" - ).input_ids.to(torch_device) - - outputs = timesfm_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) - generated_text = timesfm_tokenizer.batch_decode(outputs, skip_special_tokens=True) - - self.assertListEqual( - generated_text, - [ - "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " - "permanent residence after the marriages, prosecutors say." - ], - ) - - -@require_torch -class TestAsymmetricTimesFM(unittest.TestCase): - def build_model_and_check_forward_pass(self, **kwargs): - tester = TimesFMModelTester(self, **kwargs) - config, *inputs = tester.prepare_config_and_inputs() - ( - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) = inputs - model = TimesFMForPrediction(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - labels=lm_labels, - ) - # outputs = model(*inputs) - assert len(outputs) == 4 - assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size) - assert outputs["loss"].size() == () - return model - - def test_small_decoder(self): - # num_hidden_layers is passed to TimesFMConfig as num_layers - model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2) - assert len(model.encoder.block) == 2 - assert len(model.decoder.block) == 1 - - def test_defaulting_to_symmetry(self): - # num_hidden_layers is passed to TimesFMConfig as num_layers - model = self.build_model_and_check_forward_pass(num_hidden_layers=2) - assert len(model.decoder.block) == len(model.encoder.block) == 2 + model.eval() + results = model.run_model(**inputs_dict) + assert results From c06bbe1cf6d31697d0bfb3e027ef2d84f2de5314 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 10 Oct 2024 17:07:48 -0700 Subject: [PATCH 017/242] the model runs --- src/transformers/models/auto/modeling_auto.py | 4 +-- src/transformers/models/timesfm/__init__.py | 2 +- .../models/timesfm/configuration_timesfm.py | 2 -- .../models/timesfm/modeling_timesfm.py | 30 ++++++++----------- .../models/timesfm/timesfm_layers.py | 3 +- tests/models/timesfm/test_modeling_timesfm.py | 3 -- 6 files changed, 16 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f63bb0fcd548..10fe8935496e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -352,7 +352,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForPrediction"), + ("timesfm", "TimesFMModel"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -446,7 +446,6 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForPrediction"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -870,7 +869,6 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForPrediction"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 91f4693ae2e5..82bbb6be22ce 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -29,7 +29,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["timesfm"] = [ + _import_structure["modeling_timesfm"] = [ "TimesFMModel", "TimesFMPreTrainedModel", ] diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 29948593aff5..69397214c7ce 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -101,7 +101,6 @@ def __init__( use_positional_embedding: bool = True, per_core_batch_size: int = 32, initializer_factor: float = 1.0, - backend: str = "gpu", **kwargs, ): self.patch_len = patch_len @@ -120,7 +119,6 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.backend = backend super().__init__( **kwargs, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 612493e7f7dd..ba1c2da9c1c8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -293,7 +293,6 @@ def __init__(self, config: TimesFMConfig): self.output_patch_len = config.horizon_len self.num_layers = config.num_layers self.model_dims = config.model_dim - self.backend = config.backend self.quantiles = config.quantiles self.num_heads = config.num_heads @@ -301,7 +300,6 @@ def __init__(self, config: TimesFMConfig): self.per_core_batch_size = config.per_core_batch_size self.global_batch_size = config.per_core_batch_size * self.num_cores self._horizon_start = self.context_len - self.input_patch_len - self._device = config.backend def _preprocess( self, inputs: Sequence[np.array], freq: Sequence[int] @@ -429,7 +427,7 @@ def forward( ], dtype=np.float32, ) - ).to(self._device) + ) input_padding_in = torch.from_numpy( np.array( input_padding[ @@ -439,22 +437,18 @@ def forward( ], dtype=np.float32, ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) ) + inp_freq_in = torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ).long() mean_output, full_output = self.decoder.decode( input_ts=input_ts_in, paddings=input_padding_in, diff --git a/src/transformers/models/timesfm/timesfm_layers.py b/src/transformers/models/timesfm/timesfm_layers.py index 713623eb98a7..0ba6f2c6f54d 100644 --- a/src/transformers/models/timesfm/timesfm_layers.py +++ b/src/transformers/models/timesfm/timesfm_layers.py @@ -56,8 +56,9 @@ def _get_patch_index(arr: torch.Tensor): mask = 1 - pad # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) num_valid_elements = torch.where( - torch.sum(mask, dim=1) == 0, + num_valid_elements == 0, torch.tensor( 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device ), diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 2aafe4133c8d..ebf5df4d2cf3 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -68,7 +68,6 @@ def __init__( use_positional_embedding: bool = True, per_core_batch_size: int = 32, initializer_factor: float = 1.0, - backend: str = "gpu", ): self.parent = parent self.patch_len = patch_len @@ -87,7 +86,6 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.backend = backend def get_large_model_config(self): return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") @@ -110,7 +108,6 @@ def get_config(self): use_positional_embedding=self.use_positional_embedding, per_core_batch_size=self.per_core_batch_size, initializer_factor=self.initializer_factor, - backend=self.backend, ) def get_pipeline_config(self): From 18439762ec5eff4510c248d32bbaee7bec08b6c6 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 24 Oct 2024 09:50:18 -0700 Subject: [PATCH 018/242] fixing unit tests --- .../models/timesfm/configuration_timesfm.py | 23 ++-- .../models/timesfm/modeling_timesfm.py | 101 ++++++++++++-- .../models/timesfm/timesfm_layers.py | 20 ++- tests/models/timesfm/test_modeling_timesfm.py | 124 +++++++++++++++--- 4 files changed, 224 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 69397214c7ce..6e6dc8aec307 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -37,30 +37,28 @@ class TimesFMConfig(PretrainedConfig): Arguments: patch_len (`int`, *optional*, defaults to 32): The length of one patch in the input sequence. - horizon_len (`int`, *optional*, defaults to 128): - The length of the prediction horizon. context_len (`int`, *optional*, defaults to 512): The length of the input context. + horizon_len (`int`, *optional*, defaults to 128): + The length of the prediction horizon. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. + num_layers (`int`, *optional*, defaults to 20): + Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * head_dim`. - num_layers (`int`, *optional*, defaults to 20): - Number of Transformer layers. num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. - tolerance (`float`, *optional*, defaults to 1e-6): - The tolerance for the quantile loss. dropout_rate (`float`, *optional*, defaults to 0.1): The ratio for all dropout layers. - classifier_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for classifier. - rms_norm_eps (`float`, *optional*, defaults to 1e-6): + tolerance (`float`, *optional*, defaults to 1e-06): + The tolerance for the quantile loss. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the RMS normalization layers. - quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`): The quantiles to predict. pad_val (`float`, *optional*, defaults to 1123581321.0): The value used to pad the predictions. @@ -68,11 +66,9 @@ class TimesFMConfig(PretrainedConfig): Whether to add positional embeddings. per_core_batch_size (`int`, *optional*, defaults to 32): The batch size per core for data parallelism. - initializer_factor (`float`, *optional*, defaults to 1): + initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - backend (`str`, *optional*, defaults to `"gpu"`): - The backend to use for the model. Can be either `"gpu"` or `"cpu"`. """ model_type = "timesfm" @@ -82,6 +78,7 @@ class TimesFMConfig(PretrainedConfig): "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } + is_encoder_decoder = False def __init__( self, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ba1c2da9c1c8..8757647075f6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -23,21 +23,31 @@ import logging +from dataclasses import dataclass from typing import Any, Sequence -import pandas as pd + import numpy as np import torch import torch.nn as nn + +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from .configuration_timesfm import TimesFMConfig from .timesfm_layers import * +@dataclass +class TimesFMOutput(BaseModelOutput): + mean_predictions: np.ndarray = None + full_predictions: np.ndarray = None + + class TimesFMPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" config_class = TimesFMConfig base_model_prefix = "timesfm" + main_input_name = "inputs" def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -153,7 +163,9 @@ def _preprocess_input( # B x N x D patched_inputs = patched_inputs * (1.0 - patched_pads) + print(">>> PatchedDecoder patched_inputs", patched_inputs.shape) concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + print(">>> PatchedDecoder concat_inputs", concat_inputs.shape) model_input = self.input_ff_layer(concat_inputs) # A patch should not be padded even if there is at least one zero. @@ -190,7 +202,10 @@ def forward( input_ts: torch.Tensor, input_padding: torch.LongTensor, freq: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, ) -> torch.Tensor: + print(">>> PatchedDecoder input_ts", input_ts.shape) num_outputs = len(self.config.quantiles) + 1 model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, @@ -198,10 +213,14 @@ def forward( ) f_emb = self.freq_emb(freq) # B x 1 x D model_input += f_emb - model_output = self.stacked_transformer(model_input, patched_padding) + + print(">>> PatchedDecoder model_input", model_input.shape) + model_output, all_attentions, all_hidden_states = self.stacked_transformer(model_input, patched_padding, output_attentions=output_attentions, output_hidden_states=output_hidden_states) + if output_hidden_states: + all_hidden_states = [model_input] + all_hidden_states output_ts = self._postprocess_output(model_output, num_outputs, stats) - return output_ts + return output_ts, all_attentions, all_hidden_states def decode( self, @@ -212,7 +231,9 @@ def decode( output_patch_len: int | None = None, max_len: int = 512, return_forecast_on_context: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + output_attentions: bool = False, + output_hidden_states: bool = False, + ): """Auto-regressive decoding without caching. Args: @@ -249,7 +270,7 @@ def decode( current_padding = paddings[:, 0 : final_out.shape[1]] input_ts = final_out[:, -max_len:] input_padding = current_padding[:, -max_len:] - fprop_outputs = self(input_ts, input_padding, freq) + fprop_outputs, all_attentions, all_hidden_states = self.forward(input_ts, input_padding, freq, output_attentions=output_attentions, output_hidden_states=output_hidden_states) if return_forecast_on_context and step_index == 0: # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. @@ -276,7 +297,7 @@ def decode( # `full_outputs` indexing starts at the forecast horizon. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - return (full_outputs[:, :, 0], full_outputs) + return full_outputs[:, :, 0], full_outputs, all_attentions, all_hidden_states class TimesFMModel(TimesFMPreTrainedModel): @@ -321,7 +342,7 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - + print(">>> TimesFMModel _preprocess", len(inputs), inputs[0].shape) input_ts, input_padding, inp_freq = [], [], [] pmap_pad = ( @@ -353,6 +374,8 @@ def _preprocess( input_padding.append(input_padding[-1]) inp_freq.append(inp_freq[-1]) + print(">>> TimesFMModel input_ts", len(input_ts), input_ts[0].shape) + return ( np.stack(input_ts, axis=0), np.stack(input_padding, axis=0), @@ -368,6 +391,9 @@ def forward( forecast_context_len: int | None = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Forecasts on a list of time series. @@ -394,12 +420,15 @@ def forward( Raises: ValueError: If the checkpoint is not properly loaded. """ + if return_dict is None: + return_dict = self.config.use_return_dict if forecast_context_len is None: fcontext_len = self.context_len else: fcontext_len = forecast_context_len inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + print(">>> TimesFMModel forward", len(inputs), inputs[0].shape) inp_min = np.min([np.min(ts) for ts in inputs]) if window_size is not None: @@ -412,10 +441,18 @@ def forward( logging.info("No frequency provided via `freq`. Default to high (0).") freq = [0] * len(inputs) + if output_attentions is None: + output_attentions = self.config.output_attentions + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + print(">>> TimesFMModel input_ts", input_ts.shape) with torch.no_grad(): mean_outputs = [] full_outputs = [] + all_attentions = [] + all_hidden_states = [] assert input_ts.shape[0] % self.global_batch_size == 0 for i in range(input_ts.shape[0] // self.global_batch_size): input_ts_in = torch.from_numpy( @@ -449,12 +486,14 @@ def forward( dtype=np.int32, ) ).long() - mean_output, full_output = self.decoder.decode( + mean_output, full_output, attentions, hidden_states = self.decoder.decode( input_ts=input_ts_in, paddings=input_padding_in, freq=inp_freq_in, horizon_len=self.horizon_len, return_forecast_on_context=return_forecast_on_context, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, ) mean_output = mean_output.detach().cpu().numpy() full_output = full_output.detach().cpu().numpy() @@ -463,9 +502,36 @@ def forward( mean_outputs.append(mean_output) full_outputs.append(full_output) + if output_attentions: + if not all_attentions: + all_attentions = [[] for _ in range(len(attentions))] + for j in range(len(attentions)): + attentions[j] = attentions[j].detach().cpu().numpy() + attentions[j] = np.array(attentions[j]) + all_attentions[j].append(attentions[j]) + if output_hidden_states: + if not all_hidden_states: + all_hidden_states = [[] for _ in range(len(hidden_states))] + for j in range(len(hidden_states)): + hidden_states[j] = hidden_states[j].detach().cpu().numpy() + hidden_states[j] = np.array(hidden_states[j]) + all_hidden_states[j].append(hidden_states[j]) + mean_outputs = np.concatenate(mean_outputs, axis=0) full_outputs = np.concatenate(full_outputs, axis=0) + if output_attentions: + for j in range(len(all_attentions)): + all_attentions[j] = np.concatenate(all_attentions[j], axis=0) + if output_hidden_states: + for j in range(len(all_hidden_states)): + all_hidden_states[j] = np.concatenate(all_hidden_states[j], axis=0) + + if output_attentions: + print(">> TimesFMModel attentions", len(attentions), attentions[0].shape) + if output_hidden_states: + print(">> TimesFMModel hidden_states", len(hidden_states), hidden_states[0].shape) + if pmap_pad > 0: mean_outputs = mean_outputs[:-pmap_pad, ...] full_outputs = full_outputs[:-pmap_pad, ...] @@ -476,4 +542,21 @@ def forward( if inp_min >= 0 and truncate_negative: mean_outputs = np.maximum(mean_outputs, 0.0) full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs + + if return_dict: + result = TimesFMOutput() + result.mean_predictions = mean_outputs + result.full_predictions = full_outputs + if output_attentions: + result.attentions = all_attentions + if output_hidden_states: + result.hidden_states = all_hidden_states + + return result + else: + return_tuple = [mean_outputs, full_outputs] + if output_attentions: + return_tuple.append(all_attentions) + if output_hidden_states: + return_tuple.append(all_hidden_states) + return tuple(return_tuple) diff --git a/src/transformers/models/timesfm/timesfm_layers.py b/src/transformers/models/timesfm/timesfm_layers.py index 0ba6f2c6f54d..91fd460a120d 100644 --- a/src/transformers/models/timesfm/timesfm_layers.py +++ b/src/transformers/models/timesfm/timesfm_layers.py @@ -17,10 +17,11 @@ import math from typing import List, Tuple + import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn def masked_mean_std( @@ -379,6 +380,7 @@ def forward( hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 + print(">>> TimesFMAttention hidden_states_shape", hidden_states_shape) batch_size, input_len, _ = hidden_states_shape qkv = self.qkv_proj(hidden_states) @@ -461,6 +463,7 @@ def forward( kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: # Self Attention + print(">>> TimesFMDecoderLayer hidden_states", hidden_states.shape) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) scores, hidden_states = self.self_attn( @@ -511,21 +514,32 @@ def forward( paddings: torch.Tensor, kv_write_indices: torch.Tensor | None = None, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, ) -> torch.Tensor: + print(">>> StackedDecoder hidden_states", hidden_states.shape) padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) atten_mask = causal_mask(hidden_states) mask = merge_masks(padding_mask, atten_mask) + all_attentions = [] + all_hidden_states = [] + for i in range(len(self.layers)): layer = self.layers[i] kv_cache = kv_caches[i] if kv_caches is not None else None - _, hidden_states = layer( + scores, hidden_states = layer( hidden_states=hidden_states, mask=mask, paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, ) - return hidden_states + if output_attentions: + all_attentions.append(scores) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return hidden_states, all_attentions, all_hidden_states class PositionalEmbedding(torch.nn.Module): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index ebf5df4d2cf3..f22042ec6c71 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -13,37 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import numpy as np +import inspect import unittest from typing import List +import numpy as np +import torch + from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( - require_accelerate, - require_sentencepiece, - require_tokenizers, require_torch, - slow, torch_device, ) -from transformers.utils import cached_property, is_torch_fx_available +from transformers.utils import is_torch_fx_available -from ...generation.test_utils import GenerationTesterMixin +# from ...generation.test_utils import GenerationTesterMixin +# define our own GenerationTesters from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin -from ...test_pipeline_mixin import PipelineTesterMixin + + +# from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace + pass if is_torch_available(): - import torch from transformers import ( - AutoTokenizer, TimesFMModel, ) @@ -68,6 +67,7 @@ def __init__( use_positional_embedding: bool = True, per_core_batch_size: int = 32, initializer_factor: float = 1.0, + is_training: bool = False, ): self.parent = parent self.patch_len = patch_len @@ -78,14 +78,15 @@ def __init__( self.freq_size = freq_size self.model_dim = model_dim self.head_dim = head_dim - self.num_layers = num_layers - self.num_heads = num_heads + self.num_hidden_layers = num_layers + self.num_attention_heads = num_heads self.dropout_rate = dropout_rate self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor + self.is_training = is_training def get_large_model_config(self): return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") @@ -100,8 +101,8 @@ def get_config(self): freq_size=self.freq_size, model_dim=self.model_dim, head_dim=self.head_dim, - num_layers=self.num_layers, - num_heads=self.num_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, dropout_rate=self.dropout_rate, tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, @@ -145,7 +146,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFMModelTest( - ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase + ModelTesterMixin, unittest.TestCase ): all_model_classes = (TimesFMModel,) if is_torch_available() else () all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () @@ -155,6 +156,7 @@ class TimesFMModelTest( test_resize_embeddings = False test_model_parallel = False is_encoder_decoder = False + test_inputs_embeds = False def setUp(self): self.model_tester = TimesFMModelTester(self) @@ -165,5 +167,89 @@ def test_create_and_run_model(self): model = TimesFMModel(config) model.to(torch_device) model.eval() - results = model.run_model(**inputs_dict) - assert results + results = model(**inputs_dict) + assert results.mean_predictions is not None + + def test_attention_outputs(self): + if not self.has_attentions: + self.skipTest(reason="Model does not output attentions") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + @unittest.skip(reason="Model does not have input embeddings") + def test_model_get_set_embeddings(self): + pass + + # the main input name is `inputs` + def test_model_main_input_name(self): + model_signature = inspect.signature(getattr(TimesFMModel, "forward")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1] + self.assertEqual(TimesFMModel.main_input_name, observed_main_input_name) From d6357d5de8762df1c6708d02495abac06341ad3f Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 6 Nov 2024 15:29:13 -0800 Subject: [PATCH 019/242] fixing unit tests in progress --- .../models/timesfm/configuration_timesfm.py | 9 +++--- .../models/timesfm/modeling_timesfm.py | 32 +++++++++---------- tests/models/timesfm/test_modeling_timesfm.py | 24 +++++++++----- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 6e6dc8aec307..0ff463ba270d 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -64,8 +64,8 @@ class TimesFMConfig(PretrainedConfig): The value used to pad the predictions. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. - per_core_batch_size (`int`, *optional*, defaults to 32): - The batch size per core for data parallelism. + batch_size (`int`, *optional*, defaults to 32): + The batch size. initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). @@ -96,7 +96,7 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - per_core_batch_size: int = 32, + batch_size: int = 32, initializer_factor: float = 1.0, **kwargs, ): @@ -114,10 +114,11 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.per_core_batch_size = per_core_batch_size + self.batch_size = batch_size self.initializer_factor = initializer_factor super().__init__( + is_encoder_decoder=self.is_encoder_decoder, **kwargs, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8757647075f6..f9f0a9a8ce30 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -50,11 +50,14 @@ class TimesFMPreTrainedModel(PreTrainedModel): main_input_name = "inputs" def _init_weights(self, module): + print(">>> TimesFMPreTrainedModel _init_weights") if isinstance(module, nn.Embedding): - nn.init.uniform_(module.weight, a=-0.1, b=0.1) + print(">>> TimesFMPreTrainedModel Embedding std", self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_factor) elif isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight) + print(">>> TimesFMPreTrainedModel Linear std", self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.bias is not None: nn.init.zeros_(module.bias) @@ -316,10 +319,7 @@ def __init__(self, config: TimesFMConfig): self.model_dims = config.model_dim self.quantiles = config.quantiles self.num_heads = config.num_heads - - self.num_cores = 1 - self.per_core_batch_size = config.per_core_batch_size - self.global_batch_size = config.per_core_batch_size * self.num_cores + self.batch_size = config.batch_size self._horizon_start = self.context_len - self.input_patch_len def _preprocess( @@ -346,8 +346,8 @@ def _preprocess( input_ts, input_padding, inp_freq = [], [], [] pmap_pad = ( - (len(inputs) - 1) // self.global_batch_size + 1 - ) * self.global_batch_size - len(inputs) + (len(inputs) - 1) // self.batch_size + 1 + ) * self.batch_size - len(inputs) for i, ts in enumerate(inputs): input_len = ts.shape[0] @@ -453,14 +453,14 @@ def forward( full_outputs = [] all_attentions = [] all_hidden_states = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): + assert input_ts.shape[0] % self.batch_size == 0 + for i in range(input_ts.shape[0] // self.batch_size): input_ts_in = torch.from_numpy( np.array( input_ts[ i - * self.global_batch_size : (i + 1) - * self.global_batch_size + * self.batch_size : (i + 1) + * self.batch_size ], dtype=np.float32, ) @@ -469,8 +469,8 @@ def forward( np.array( input_padding[ i - * self.global_batch_size : (i + 1) - * self.global_batch_size + * self.batch_size : (i + 1) + * self.batch_size ], dtype=np.float32, ) @@ -479,8 +479,8 @@ def forward( np.array( inp_freq[ i - * self.global_batch_size : (i + 1) - * self.global_batch_size, + * self.batch_size : (i + 1) + * self.batch_size, :, ], dtype=np.int32, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index f22042ec6c71..645ad268448e 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -55,18 +55,18 @@ def __init__( context_len: int = 512, horizon_len: int = 128, freq_size: int = 3, - num_layers: int = 20, - model_dim: int = 1280, - head_dim: int = 80, - num_heads: int = 16, + num_layers: int = 4, + model_dim: int = 128, + head_dim: int = 16, + num_heads: int = 4, dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - per_core_batch_size: int = 32, - initializer_factor: float = 1.0, + batch_size: int = 32, + initializer_factor: float = 0.0, is_training: bool = False, ): self.parent = parent @@ -84,9 +84,13 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.per_core_batch_size = per_core_batch_size + self.batch_size = batch_size self.initializer_factor = initializer_factor self.is_training = is_training + + # The size of test input + self.seq_length = context_len // patch_len + self.hidden_size = model_dim def get_large_model_config(self): return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") @@ -107,7 +111,7 @@ def get_config(self): tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, use_positional_embedding=self.use_positional_embedding, - per_core_batch_size=self.per_core_batch_size, + batch_size=self.batch_size, initializer_factor=self.initializer_factor, ) @@ -247,6 +251,10 @@ def test_attention_outputs(self): def test_model_get_set_embeddings(self): pass + @unittest.skip(reason="Model does not have head mask") + def test_headmasking(self): + pass + # the main input name is `inputs` def test_model_main_input_name(self): model_signature = inspect.signature(getattr(TimesFMModel, "forward")) From fba8f52bc6aabbb072a39d8b732a16bc7d33560e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 7 Nov 2024 19:29:56 +0100 Subject: [PATCH 020/242] add post_init --- .../models/timesfm/modeling_timesfm.py | 16 +++++++++++++++- tests/models/timesfm/test_modeling_timesfm.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f9f0a9a8ce30..dc9d36908736 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -33,7 +33,15 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from .configuration_timesfm import TimesFMConfig -from .timesfm_layers import * +from .timesfm_layers import ( + PositionalEmbedding, + ResidualBlock, + RMSNorm, + StackedDecoder, + masked_mean_std, + moving_average, + shift_padded_seq, +) @dataclass @@ -105,6 +113,9 @@ def __init__(self, config: TimesFMConfig): self.position_emb = PositionalEmbedding( embedding_dims=self.config.model_dim, ) + + # Initialize weights and apply final processing + self.post_init() def _forward_transform( self, inputs: torch.Tensor, patched_pads: torch.Tensor @@ -322,6 +333,9 @@ def __init__(self, config: TimesFMConfig): self.batch_size = config.batch_size self._horizon_start = self.context_len - self.input_patch_len + # Initialize weights and apply final processing + self.post_init() + def _preprocess( self, inputs: Sequence[np.array], freq: Sequence[int] ) -> tuple[np.array, np.array, int]: diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 645ad268448e..8f7853398147 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -87,7 +87,7 @@ def __init__( self.batch_size = batch_size self.initializer_factor = initializer_factor self.is_training = is_training - + # The size of test input self.seq_length = context_len // patch_len self.hidden_size = model_dim From 0c0e51abbaacd95b3d37a94b1ff460afe9889ece Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 7 Nov 2024 19:40:01 +0100 Subject: [PATCH 021/242] do not change TimesFMOutput --- .../models/timesfm/modeling_timesfm.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index dc9d36908736..e6ac77c6a418 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -558,15 +558,12 @@ def forward( full_outputs = np.maximum(full_outputs, 0.0) if return_dict: - result = TimesFMOutput() - result.mean_predictions = mean_outputs - result.full_predictions = full_outputs - if output_attentions: - result.attentions = all_attentions - if output_hidden_states: - result.hidden_states = all_hidden_states - - return result + return TimesFMOutput( + mean_predictions=mean_outputs, + full_predictions=full_outputs, + attentions=all_attentions if output_attentions else None, + hidden_states=all_hidden_states if output_hidden_states else None, + ) else: return_tuple = [mean_outputs, full_outputs] if output_attentions: From fb702f69f7a9dcbdf7545a474ba33ef08f83bfbd Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 13 Nov 2024 17:11:00 -0800 Subject: [PATCH 022/242] fixing unit tests --- .../models/timesfm/modeling_timesfm.py | 147 ++++++++---------- tests/models/timesfm/test_modeling_timesfm.py | 73 --------- 2 files changed, 69 insertions(+), 151 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e6ac77c6a418..f34f0b64deb0 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -58,13 +58,10 @@ class TimesFMPreTrainedModel(PreTrainedModel): main_input_name = "inputs" def _init_weights(self, module): - print(">>> TimesFMPreTrainedModel _init_weights") if isinstance(module, nn.Embedding): - print(">>> TimesFMPreTrainedModel Embedding std", self.config.initializer_factor) module.weight.data.normal_(mean=0, std=self.config.initializer_factor) elif isinstance(module, nn.Linear): - print(">>> TimesFMPreTrainedModel Linear std", self.config.initializer_factor) module.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.bias is not None: nn.init.zeros_(module.bias) @@ -462,84 +459,77 @@ def forward( input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) print(">>> TimesFMModel input_ts", input_ts.shape) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - all_attentions = [] - all_hidden_states = [] - assert input_ts.shape[0] % self.batch_size == 0 - for i in range(input_ts.shape[0] // self.batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + mean_outputs = [] + full_outputs = [] + all_attentions = [] + all_hidden_states = [] + assert input_ts.shape[0] % self.batch_size == 0 + for i in range(input_ts.shape[0] // self.batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.batch_size : (i + 1) + * self.batch_size + ], + dtype=np.float32, ) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + ) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.batch_size : (i + 1) + * self.batch_size + ], + dtype=np.float32, ) - inp_freq_in = torch.from_numpy( - np.array( - inp_freq[ - i - * self.batch_size : (i + 1) - * self.batch_size, - :, - ], - dtype=np.int32, - ) - ).long() - mean_output, full_output, attentions, hidden_states = self.decoder.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + ) + inp_freq_in = torch.from_numpy( + np.array( + inp_freq[ + i + * self.batch_size : (i + 1) + * self.batch_size, + :, + ], + dtype=np.int32, ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - if output_attentions: - if not all_attentions: - all_attentions = [[] for _ in range(len(attentions))] - for j in range(len(attentions)): - attentions[j] = attentions[j].detach().cpu().numpy() - attentions[j] = np.array(attentions[j]) - all_attentions[j].append(attentions[j]) - if output_hidden_states: - if not all_hidden_states: - all_hidden_states = [[] for _ in range(len(hidden_states))] - for j in range(len(hidden_states)): - hidden_states[j] = hidden_states[j].detach().cpu().numpy() - hidden_states[j] = np.array(hidden_states[j]) - all_hidden_states[j].append(hidden_states[j]) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) + ).long() + mean_output, full_output, attentions, hidden_states = self.decoder.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + if output_attentions: + if not all_attentions: + all_attentions = [[] for _ in range(len(attentions))] + for j in range(len(attentions)): + attentions[j] = attentions[j] + all_attentions[j].append(attentions[j]) + if output_hidden_states: + if not all_hidden_states: + all_hidden_states = [[] for _ in range(len(hidden_states))] + for j in range(len(hidden_states)): + hidden_states[j] = hidden_states[j] + all_hidden_states[j].append(hidden_states[j]) + + mean_outputs = torch.cat(mean_outputs, axis=0) + full_outputs = torch.cat(full_outputs, axis=0) if output_attentions: for j in range(len(all_attentions)): - all_attentions[j] = np.concatenate(all_attentions[j], axis=0) + all_attentions[j] = torch.cat(all_attentions[j], axis=0) if output_hidden_states: for j in range(len(all_hidden_states)): - all_hidden_states[j] = np.concatenate(all_hidden_states[j], axis=0) + all_hidden_states[j] = torch.cat(all_hidden_states[j], axis=0) if output_attentions: print(">> TimesFMModel attentions", len(attentions), attentions[0].shape) @@ -554,8 +544,8 @@ def forward( mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) + mean_outputs = torch.maximum(mean_outputs, 0.0) + full_outputs = torch.maximum(full_outputs, 0.0) if return_dict: return TimesFMOutput( @@ -565,9 +555,10 @@ def forward( hidden_states=all_hidden_states if output_hidden_states else None, ) else: - return_tuple = [mean_outputs, full_outputs] - if output_attentions: - return_tuple.append(all_attentions) + return_tuple = [] if output_hidden_states: return_tuple.append(all_hidden_states) + if output_attentions: + return_tuple.append(all_attentions) + return_tuple += [mean_outputs, full_outputs] return tuple(return_tuple) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 8f7853398147..c928cf9aca8e 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -174,79 +174,6 @@ def test_create_and_run_model(self): results = model(**inputs_dict) assert results.mean_predictions is not None - def test_attention_outputs(self): - if not self.has_attentions: - self.skipTest(reason="Model does not output attentions") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - - seq_len = getattr(self.model_tester, "seq_length", None) - decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) - encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) - chunk_length = getattr(self.model_tester, "chunk_length", None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - out_len = len(outputs) - - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - - self.assertEqual(out_len + added_hidden_states, len(outputs)) - - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - @unittest.skip(reason="Model does not have input embeddings") def test_model_get_set_embeddings(self): pass From 3110e9df984e317425f9d0a8e68d3fbde13b0712 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 14 Nov 2024 17:48:25 -0800 Subject: [PATCH 023/242] all unit tests passed --- .../models/timesfm/modeling_timesfm.py | 124 +++++------------- 1 file changed, 31 insertions(+), 93 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f34f0b64deb0..f33c169f2555 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -308,7 +308,7 @@ def decode( # `full_outputs` indexing starts at the forecast horizon. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - return full_outputs[:, :, 0], full_outputs, all_attentions, all_hidden_states + return full_outputs[:, :, 0], full_outputs, fprop_outputs, all_attentions, all_hidden_states class TimesFMModel(TimesFMPreTrainedModel): @@ -356,10 +356,6 @@ def _preprocess( print(">>> TimesFMModel _preprocess", len(inputs), inputs[0].shape) input_ts, input_padding, inp_freq = [], [], [] - pmap_pad = ( - (len(inputs) - 1) // self.batch_size + 1 - ) * self.batch_size - len(inputs) - for i, ts in enumerate(inputs): input_len = ts.shape[0] padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) @@ -379,19 +375,12 @@ def _preprocess( input_padding.append(padding) inp_freq.append(freq[i]) - # Padding the remainder batch. - for _ in range(pmap_pad): - input_ts.append(input_ts[-1]) - input_padding.append(input_padding[-1]) - inp_freq.append(inp_freq[-1]) - print(">>> TimesFMModel input_ts", len(input_ts), input_ts[0].shape) return ( np.stack(input_ts, axis=0), np.stack(input_padding, axis=0), np.array(inp_freq).astype(np.int32).reshape(-1, 1), - pmap_pad, ) def forward( @@ -457,88 +446,36 @@ def forward( if output_hidden_states is None: output_hidden_states = self.config.output_hidden_states - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) print(">>> TimesFMModel input_ts", input_ts.shape) - mean_outputs = [] - full_outputs = [] - all_attentions = [] - all_hidden_states = [] - assert input_ts.shape[0] % self.batch_size == 0 - for i in range(input_ts.shape[0] // self.batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + + input_ts_in = torch.from_numpy( + np.array( + input_ts, + dtype=np.float32, ) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + ) + input_padding_in = torch.from_numpy( + np.array( + input_padding, + dtype=np.float32, ) - inp_freq_in = torch.from_numpy( - np.array( - inp_freq[ - i - * self.batch_size : (i + 1) - * self.batch_size, - :, - ], - dtype=np.int32, - ) - ).long() - mean_output, full_output, attentions, hidden_states = self.decoder.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + ) + inp_freq_in = torch.from_numpy( + np.array( + inp_freq, + dtype=np.int32, ) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - if output_attentions: - if not all_attentions: - all_attentions = [[] for _ in range(len(attentions))] - for j in range(len(attentions)): - attentions[j] = attentions[j] - all_attentions[j].append(attentions[j]) - if output_hidden_states: - if not all_hidden_states: - all_hidden_states = [[] for _ in range(len(hidden_states))] - for j in range(len(hidden_states)): - hidden_states[j] = hidden_states[j] - all_hidden_states[j].append(hidden_states[j]) - - mean_outputs = torch.cat(mean_outputs, axis=0) - full_outputs = torch.cat(full_outputs, axis=0) - - if output_attentions: - for j in range(len(all_attentions)): - all_attentions[j] = torch.cat(all_attentions[j], axis=0) - if output_hidden_states: - for j in range(len(all_hidden_states)): - all_hidden_states[j] = torch.cat(all_hidden_states[j], axis=0) - - if output_attentions: - print(">> TimesFMModel attentions", len(attentions), attentions[0].shape) - if output_hidden_states: - print(">> TimesFMModel hidden_states", len(hidden_states), hidden_states[0].shape) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] + ).long() + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decoder.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] @@ -549,13 +486,14 @@ def forward( if return_dict: return TimesFMOutput( - mean_predictions=mean_outputs, - full_predictions=full_outputs, + last_hidden_state=last_hidden_state, attentions=all_attentions if output_attentions else None, hidden_states=all_hidden_states if output_hidden_states else None, + mean_predictions=mean_outputs, + full_predictions=full_outputs, ) else: - return_tuple = [] + return_tuple = [last_hidden_state] if output_hidden_states: return_tuple.append(all_hidden_states) if output_attentions: From ec187e7851cbab68f3cb63ac996c1e313cfc52b1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 24 Nov 2024 16:07:49 +0100 Subject: [PATCH 024/242] remove timesfm_layers --- src/transformers/models/timesfm/__init__.py | 4 +- .../models/timesfm/modeling_timesfm.py | 605 ++++++++++++++++-- .../models/timesfm/timesfm_layers.py | 597 ----------------- tests/models/timesfm/test_modeling_timesfm.py | 1 - 4 files changed, 550 insertions(+), 657 deletions(-) delete mode 100644 src/transformers/models/timesfm/timesfm_layers.py diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 82bbb6be22ce..6592a5b1620e 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -51,6 +51,4 @@ else: import sys - sys.modules[__name__] = _LazyModule( - __name__, globals()["__file__"], _import_structure, module_spec=__spec__ - ) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f33c169f2555..9e0959b8e12f 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -21,27 +21,19 @@ # - PreTrainedModel for the models (it-self a sub-class of nn.Module) #################################################### - import logging +import math from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any, List, Sequence, Tuple import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from .configuration_timesfm import TimesFMConfig -from .timesfm_layers import ( - PositionalEmbedding, - ResidualBlock, - RMSNorm, - StackedDecoder, - masked_mean_std, - moving_average, - shift_padded_seq, -) @dataclass @@ -50,6 +42,521 @@ class TimesFMOutput(BaseModelOutput): full_predictions: np.ndarray = None +class TimesFMTransformerMLP(nn.Module): + """Pax transformer MLP in pytorch.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFMResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super().__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.hidden_layer(x) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class TimesFMRMSNorm(torch.nn.Module): + """Pax rms norm in pytorch.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + if self.add_unit_offset: + output = output * (1 + self.weight.float()) + else: + output = output * self.weight.float() + return output.type_as(x) + + +class TimesFMPositionalEmbedding(nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__( + self, + embedding_dims: int, + min_timescale: int = 1, + max_timescale: int = 10_000, + ) -> None: + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + assert seq_length is not None + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + assert position.ndim == 2, position.shape + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max( + num_timescales - 1, 1 + ) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +class TimesFMAttention(nn.Module): + """Implements the attention used in TimesFM.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = nn.Parameter( + torch.empty((self.head_dim,), dtype=torch.float32), + ) + + self.qkv_proj = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: + # [batch_size, n_local_heads, input_len, head_dim] + r_softplus_0 = 1.442695041 + softplus_func = torch.nn.Softplus() + scale = r_softplus_0 / math.sqrt(self.head_dim) + scale = scale * softplus_func(self.scaling) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xq = self._per_dim_scaling(xq) + + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + else: + key = xk + value = xv + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + # return scores, output.transpose(1, 2).contiguous() + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return scores, output + + +class TimesFMDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.self_attn = TimesFMAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + ) + self.mlp = TimesFMTransformerMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + self.input_layernorm = TimesFMRMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + scores, hidden_states = self.self_attn( + hidden_states=hidden_states, + mask=mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class TimesFMStackedDecoder(nn.Module): + """Stacked transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + num_layers: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TimesFMDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + ) + ) + + def forward( + self, + hidden_states: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> torch.Tensor: + padding_mask = timesfm_convert_paddings_to_mask(paddings, hidden_states.dtype) + atten_mask = timesfm_causal_mask(hidden_states) + mask = timesfm_merge_masks(padding_mask, atten_mask) + all_attentions = [] + all_hidden_states = [] + + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = kv_caches[i] if kv_caches is not None else None + scores, hidden_states = layer( + hidden_states=hidden_states, + mask=mask, + paddings=paddings, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + if output_attentions: + all_attentions.append(scores) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return hidden_states, all_attentions, all_hidden_states + + +# Move utility functions here +def timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded values. + """ + + # Selecting the first patch with more than 3 unpadded values. + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + pad_sum = torch.sum(1 - padding, dim=2) + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.where( + num_valid_elements == 0, + torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device), + num_valid_elements, + ) + + # Calculate the masked sum and squared sum + masked_sum = torch.sum(arr * mask, dim=1) + masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) + + # Calculate the masked mean and standard deviation + masked_mean = masked_sum / num_valid_elements + masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 + masked_var = torch.where( + masked_var < 0.0, + torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), + masked_var, + ) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + +def timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + The shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = torch.arange(num_seq).to(seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, feature_dim) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + +def timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def timesfm_get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def timesfm_causal_mask(input_t: torch.Tensor) -> torch.Tensor: + """Computes and returns causal mask. + + Args: + input_t: A torch.Tensor of shape [B, T, D]. + + Returns: + An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has + already been converted to large negative values. + """ + assert input_t.dtype.is_floating_point, input_t.dtype + large_negative_number = timesfm_get_large_negative_number(input_t.dtype) + t = input_t.shape[1] + col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) + row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) + mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number + return mask.unsqueeze(0).unsqueeze(0).to(input_t.device) # Equivalent to jnp.newaxis + + +def timesfm_convert_paddings_to_mask(paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Converts binary paddings to a logit mask ready to add to attention matrix. + + Args: + paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding + token. + dtype: data type of the input. + + Returns: + A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. + """ + attention_mask = paddings.detach().clone() + attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis + attention_mask *= timesfm_get_large_negative_number(dtype) + return attention_mask + + +def timesfm_merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Merges 2 masks. + + logscale mask is expected but 0/1 mask is also fine. + + Args: + a: torch.Tensor of shape [1|B, 1, 1|T, S]. + b: torch.Tensor of shape [1|B, 1, 1|T, S]. + + Returns: + torch.Tensor of shape [1|B, 1, 1|T, S]. + """ + + def expand_t(key_mask): + query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose + return torch.minimum(query_mask, key_mask) + + if a.shape[2] != b.shape[2]: + if a.shape[2] == 1: + a = expand_t(a) + else: + assert b.shape[2] == 1 + b = expand_t(b) + + assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." + return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum + + class TimesFMPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" @@ -70,10 +577,10 @@ def _init_weights(self, module): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) - elif isinstance(module, RMSNorm): + elif isinstance(module, TimesFMRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, PositionalEmbedding): + elif isinstance(module, TimesFMPositionalEmbedding): pass @@ -84,20 +591,18 @@ def __init__(self, config: TimesFMConfig): super().__init__(config) self.config = config - self.input_ff_layer = ResidualBlock( + self.input_ff_layer = TimesFMResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, hidden_dims=config.model_dim, ) - self.freq_emb = nn.Embedding( - num_embeddings=config.freq_size, embedding_dim=config.model_dim - ) - self.horizon_ff_layer = ResidualBlock( + self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) + self.horizon_ff_layer = TimesFMResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.model_dim, ) - self.stacked_transformer = StackedDecoder( + self.stacked_transformer = TimesFMStackedDecoder( hidden_size=self.config.model_dim, intermediate_size=self.config.model_dim, num_heads=self.config.num_heads, @@ -107,10 +612,10 @@ def __init__(self, config: TimesFMConfig): rms_norm_eps=self.config.rms_norm_eps, ) if self.config.use_positional_embedding: - self.position_emb = PositionalEmbedding( + self.position_emb = TimesFMPositionalEmbedding( embedding_dims=self.config.model_dim, ) - + # Initialize weights and apply final processing self.post_init() @@ -118,7 +623,7 @@ def _forward_transform( self, inputs: torch.Tensor, patched_pads: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Input is of shape [B, N, P].""" - mu, sigma = masked_mean_std(inputs, patched_pads) + mu, sigma = timesfm_masked_mean_std(inputs, patched_pads) sigma = torch.where( sigma < self.config.tolerance, torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), @@ -129,16 +634,12 @@ def _forward_transform( outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] outputs = torch.where( torch.abs(inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor( - self.config.pad_val, dtype=outputs.dtype, device=outputs.device - ), + torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device), outputs, ) return outputs, (mu, sigma) - def _reverse_transform( - self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] - ) -> torch.Tensor: + def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: """Output is of shape [B, N, P, Q].""" mu, sigma = stats return outputs * sigma[:, None, None, None] + mu[:, None, None, None] @@ -174,19 +675,15 @@ def _preprocess_input( # B x N x D patched_inputs = patched_inputs * (1.0 - patched_pads) - print(">>> PatchedDecoder patched_inputs", patched_inputs.shape) concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) - print(">>> PatchedDecoder concat_inputs", concat_inputs.shape) model_input = self.input_ff_layer(concat_inputs) # A patch should not be padded even if there is at least one zero. - patched_padding = torch.min(patched_pads, dim=-1)[ - 0 - ] # Get the values from the min result + patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result if self.config.use_positional_embedding: pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) - pos_emb = shift_padded_seq(patched_padding, pos_emb) + pos_emb = timesfm_shift_padded_seq(patched_padding, pos_emb) model_input += pos_emb return model_input, patched_padding, stats, patched_inputs @@ -216,7 +713,6 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, ) -> torch.Tensor: - print(">>> PatchedDecoder input_ts", input_ts.shape) num_outputs = len(self.config.quantiles) + 1 model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, @@ -225,8 +721,12 @@ def forward( f_emb = self.freq_emb(freq) # B x 1 x D model_input += f_emb - print(">>> PatchedDecoder model_input", model_input.shape) - model_output, all_attentions, all_hidden_states = self.stacked_transformer(model_input, patched_padding, output_attentions=output_attentions, output_hidden_states=output_hidden_states) + model_output, all_attentions, all_hidden_states = self.stacked_transformer( + model_input, + patched_padding, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if output_hidden_states: all_hidden_states = [model_input] + all_hidden_states @@ -281,14 +781,18 @@ def decode( current_padding = paddings[:, 0 : final_out.shape[1]] input_ts = final_out[:, -max_len:] input_padding = current_padding[:, -max_len:] - fprop_outputs, all_attentions, all_hidden_states = self.forward(input_ts, input_padding, freq, output_attentions=output_attentions, output_hidden_states=output_hidden_states) + fprop_outputs, all_attentions, all_hidden_states = self.forward( + input_ts, + input_padding, + freq, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if return_forecast_on_context and step_index == 0: # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - new_full_ts = fprop_outputs.view( - new_full_ts.size(0), -1, new_full_ts.size(3) - ) + new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1, new_full_ts.size(3)) full_outputs.append(new_full_ts) @@ -333,9 +837,7 @@ def __init__(self, config: TimesFMConfig): # Initialize weights and apply final processing self.post_init() - def _preprocess( - self, inputs: Sequence[np.array], freq: Sequence[int] - ) -> tuple[np.array, np.array, int]: + def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: """Formats and pads raw inputs to feed into the model. This function both pads each time series to match the context length, and @@ -353,7 +855,6 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - print(">>> TimesFMModel _preprocess", len(inputs), inputs[0].shape) input_ts, input_padding, inp_freq = [], [], [] for i, ts in enumerate(inputs): @@ -361,12 +862,8 @@ def _preprocess( padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = np.concatenate( - [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 - ) - padding = np.concatenate( - [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 - ) + ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) + padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] @@ -375,8 +872,6 @@ def _preprocess( input_padding.append(padding) inp_freq.append(freq[i]) - print(">>> TimesFMModel input_ts", len(input_ts), input_ts[0].shape) - return ( np.stack(input_ts, axis=0), np.stack(input_padding, axis=0), @@ -428,13 +923,12 @@ def forward( else: fcontext_len = forecast_context_len inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - print(">>> TimesFMModel forward", len(inputs), inputs[0].shape) inp_min = np.min([np.min(ts) for ts in inputs]) if window_size is not None: new_inputs = [] for ts in inputs: - new_inputs.extend(moving_average(ts, window_size)) + new_inputs.extend(timesfm_moving_average(ts, window_size)) inputs = new_inputs if freq is None: @@ -447,7 +941,6 @@ def forward( output_hidden_states = self.config.output_hidden_states input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - print(">>> TimesFMModel input_ts", input_ts.shape) input_ts_in = torch.from_numpy( np.array( diff --git a/src/transformers/models/timesfm/timesfm_layers.py b/src/transformers/models/timesfm/timesfm_layers.py deleted file mode 100644 index 91fd460a120d..000000000000 --- a/src/transformers/models/timesfm/timesfm_layers.py +++ /dev/null @@ -1,597 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Pytorch version of patched decoder.""" - - -import math -from typing import List, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - - -def masked_mean_std( - inputs: torch.Tensor, padding: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Calculates mean and standard deviation of `inputs` across axis 1. - - It excludes values where `padding` is 1. - - Args: - inputs: A PyTorch tensor of shape [b, n, p]. - padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. - - Returns: - A tuple containing the mean and standard deviation. - We return the statistics of the first patch with more than three non-padded values. - """ - - # Selecting the first patch with more than 3 unpadded values. - def _get_patch_index(arr: torch.Tensor): - indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) - row_sum = (arr >= 3).to(torch.int32).sum(dim=1) - return torch.where(row_sum == 0, arr.shape[1] - 1, indices) - - pad_sum = torch.sum(1 - padding, dim=2) - patch_indices = _get_patch_index(pad_sum) - bidxs = torch.arange(inputs.shape[0]) - - arr = inputs[bidxs, patch_indices, :] - pad = padding[bidxs, patch_indices, :] - - # Create a mask where padding is 0 - mask = 1 - pad - - # Calculate the number of valid elements - num_valid_elements = torch.sum(mask, dim=1) - num_valid_elements = torch.where( - num_valid_elements == 0, - torch.tensor( - 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device - ), - num_valid_elements, - ) - - # Calculate the masked sum and squared sum - masked_sum = torch.sum(arr * mask, dim=1) - masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) - - # Calculate the masked mean and standard deviation - masked_mean = masked_sum / num_valid_elements - masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 - masked_var = torch.where( - masked_var < 0.0, - torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), - masked_var, - ) - masked_std = torch.sqrt(masked_var) - - return masked_mean, masked_std - - -def shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: - """Shifts rows of seq based on the first 0 in each row of the mask. - - Args: - mask: mask tensor of shape [B, N] - seq: seq tensor of shape [B, N, P] - - Returns: - The shifted sequence. - """ - batch_size, num_seq, feature_dim = seq.shape - - new_mask: torch.BoolTensor = mask == 0 - - # Use argmax to find the first True value in each row - indices = new_mask.to(torch.int32).argmax(dim=1) - - # Handle rows with all zeros - indices[~new_mask.any(dim=1)] = -1 - - # Create index ranges for each sequence in the batch - idx_range = ( - torch.arange(num_seq) - .to(seq.device) - .unsqueeze(0) - .unsqueeze(-1) - .expand(batch_size, -1, feature_dim) - ) - - # Calculate shifted indices for each element in each sequence - shifted_idx = (idx_range - indices[:, None, None]) % num_seq - - # Gather values from seq using shifted indices - shifted_seq = seq.gather(1, shifted_idx) - - return shifted_seq - - -def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: - """Returns a large negative value for the given dtype.""" - if dtype.is_floating_point: - dtype_max = torch.finfo(dtype).max - else: - dtype_max = torch.iinfo(dtype).max - return torch.tensor(-0.7 * dtype_max, dtype=dtype) - - -def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """Applies a floating-point mask to a set of logits. - - Args: - logits: A torch.Tensor of logit values. - mask: A torch.Tensor (float32) of mask values with the encoding described - in the function documentation. - - Returns: - Masked logits. - """ - - min_value = get_large_negative_number(logits.dtype) - - return torch.where((mask >= min_value * 0.5), logits, min_value) - - -def convert_paddings_to_mask( - paddings: torch.Tensor, dtype: torch.dtype = torch.float32 -) -> torch.Tensor: - """Converts binary paddings to a logit mask ready to add to attention matrix. - - Args: - paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding - token. - dtype: data type of the input. - - Returns: - A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. - """ - attention_mask = paddings.detach().clone() - attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis - attention_mask *= get_large_negative_number(dtype) - return attention_mask - - -def causal_mask(input_t: torch.Tensor) -> torch.Tensor: - """Computes and returns causal mask. - - Args: - input_t: A torch.Tensor of shape [B, T, D]. - - Returns: - An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has - already been converted to large negative values. - """ - assert input_t.dtype.is_floating_point, input_t.dtype - large_negative_number = get_large_negative_number(input_t.dtype) - t = input_t.shape[1] - col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) - row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) - mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number - return ( - mask.unsqueeze(0).unsqueeze(0).to(input_t.device) - ) # Equivalent to jnp.newaxis - - -def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Merges 2 masks. - - logscale mask is expected but 0/1 mask is also fine. - - Args: - a: torch.Tensor of shape [1|B, 1, 1|T, S]. - b: torch.Tensor of shape [1|B, 1, 1|T, S]. - - Returns: - torch.Tensor of shape [1|B, 1, 1|T, S]. - """ - - def expand_t(key_mask): - query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose - return torch.minimum(query_mask, key_mask) - - if a.shape[2] != b.shape[2]: - if a.shape[2] == 1: - a = expand_t(a) - else: - assert b.shape[2] == 1 - b = expand_t(b) - - assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." - return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum - - -def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key - - -def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size - return [smoothed_arr, arr - smoothed_arr] - - -def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if ( - freq.endswith("H") - or freq.endswith("T") - or freq.endswith("MIN") - or freq.endswith("D") - or freq.endswith("B") - or freq.endswith("U") - ): - return 0 - elif freq.endswith(("W", "M", "MS")): - return 1 - elif freq.endswith("Y") or freq.endswith("Q"): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") - - -class ResidualBlock(nn.Module): - """TimesFM residual block.""" - - def __init__( - self, - input_dims, - hidden_dims, - output_dims, - ): - super(ResidualBlock, self).__init__() - self.input_dims = input_dims - self.hidden_dims = hidden_dims - self.output_dims = output_dims - - # Hidden Layer - self.hidden_layer = nn.Sequential( - nn.Linear(input_dims, hidden_dims), - nn.SiLU(), - ) - - # Output Layer - self.output_layer = nn.Linear(hidden_dims, output_dims) - # Residual Layer - self.residual_layer = nn.Linear(input_dims, output_dims) - - def forward(self, x): - hidden = self.hidden_layer(x) - output = self.output_layer(hidden) - residual = self.residual_layer(x) - return output + residual - - -class RMSNorm(torch.nn.Module): - """Pax rms norm in pytorch.""" - - def __init__( - self, - dim: int, - eps: float = 1e-6, - add_unit_offset: bool = False, - ): - super().__init__() - self.eps = eps - self.add_unit_offset = add_unit_offset - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()) - if self.add_unit_offset: - output = output * (1 + self.weight.float()) - else: - output = output * self.weight.float() - return output.type_as(x) - - -class TransformerMLP(nn.Module): - """Pax transformer MLP in pytorch.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size) - self.down_proj = nn.Linear(intermediate_size, hidden_size) - self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) - - def forward(self, x, paddings=None): - gate_inp = self.layer_norm(x) - gate = self.gate_proj(gate_inp) - gate = F.relu(gate) - outputs = self.down_proj(gate) - if paddings is not None: - outputs = outputs * (1.0 - paddings[:, :, None]) - return outputs + x - - -class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM.""" - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - ): - super().__init__() - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.hidden_size = hidden_size - self.head_dim = head_dim - - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = nn.Parameter( - torch.empty((self.head_dim,), dtype=torch.float32), - ) - - self.qkv_proj = nn.Linear( - self.hidden_size, - (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, - ) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) - - def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: - # [batch_size, n_local_heads, input_len, head_dim] - r_softplus_0 = 1.442695041 - softplus_func = torch.nn.Softplus() - scale = r_softplus_0 / math.sqrt(self.head_dim) - scale = scale * softplus_func(self.scaling) - return query * scale[None, None, None, :] - - def forward( - self, - hidden_states: torch.Tensor, - mask: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - hidden_states_shape = hidden_states.shape - assert len(hidden_states_shape) == 3 - - print(">>> TimesFMAttention hidden_states_shape", hidden_states_shape) - batch_size, input_len, _ = hidden_states_shape - - qkv = self.qkv_proj(hidden_states) - xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) - xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xq = self._per_dim_scaling(xq) - - # Write new kv cache. - # [batch_size, input_len, n_local_kv_heads, head_dim] - if kv_cache is not None and kv_write_indices is not None: - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - - key = k_cache - value = v_cache - else: - key = xk - value = xv - if self.num_kv_heads != self.num_heads: - # [batch_size, max_seq_len, n_local_heads, head_dim] - key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) - value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) - - # [batch_size, n_local_heads, input_len, head_dim] - q = xq.transpose(1, 2) - # [batch_size, n_local_heads, max_seq_len, head_dim] - k = key.transpose(1, 2) - v = value.transpose(1, 2) - - # [batch_size, n_local_heads, input_len, max_seq_len] - scores = torch.matmul(q, k.transpose(2, 3)) - scores = scores + mask - scores = F.softmax(scores.float(), dim=-1).type_as(q) - - # [batch_size, n_local_heads, input_len, head_dim] - output = torch.matmul(scores, v) - # return scores, output.transpose(1, 2).contiguous() - - # [batch_size, input_len, hidden_dim] - output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) - output = self.o_proj(output) - return scores, output - - -class TimesFMDecoderLayer(nn.Module): - """Transformer layer.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - rms_norm_eps: float = 1e-6, - ): - super().__init__() - self.self_attn = TimesFMAttention( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - ) - self.mlp = TransformerMLP( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - mask: torch.Tensor, - paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - # Self Attention - print(">>> TimesFMDecoderLayer hidden_states", hidden_states.shape) - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - scores, hidden_states = self.self_attn( - hidden_states=hidden_states, - mask=mask, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - ) - hidden_states = residual + hidden_states - - # MLP - hidden_states = self.mlp(hidden_states, paddings=paddings) - - return scores, hidden_states - - -class StackedDecoder(nn.Module): - """Stacked transformer layer.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - num_layers: int, - rms_norm_eps: float = 1e-6, - ): - super().__init__() - - self.layers = nn.ModuleList() - for _ in range(num_layers): - self.layers.append( - TimesFMDecoderLayer( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - rms_norm_eps=rms_norm_eps, - ) - ) - - def forward( - self, - hidden_states: torch.Tensor, - paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - ) -> torch.Tensor: - print(">>> StackedDecoder hidden_states", hidden_states.shape) - padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) - atten_mask = causal_mask(hidden_states) - mask = merge_masks(padding_mask, atten_mask) - all_attentions = [] - all_hidden_states = [] - - for i in range(len(self.layers)): - layer = self.layers[i] - kv_cache = kv_caches[i] if kv_caches is not None else None - scores, hidden_states = layer( - hidden_states=hidden_states, - mask=mask, - paddings=paddings, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - ) - if output_attentions: - all_attentions.append(scores) - if output_hidden_states: - all_hidden_states.append(hidden_states) - - return hidden_states, all_attentions, all_hidden_states - - -class PositionalEmbedding(torch.nn.Module): - """Generates position embedding for a given 1-d sequence. - - Attributes: - min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. - max_timescale: End of the geometric index. Determines the frequency of the - added signal. - embedding_dims: Dimension of the embedding to be generated. - """ - - def __init__( - self, - embedding_dims: int, - min_timescale: int = 1, - max_timescale: int = 10_000, - ) -> None: - super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dims = embedding_dims - - def forward(self, seq_length=None, position=None): - """Generates a Tensor of sinusoids with different frequencies. - - Args: - seq_length: an optional Python int defining the output sequence length. - if the `position` argument is specified. - position: [B, seq_length], optional position for each token in the - sequence, only required when the sequence is packed. - - Returns: - [B, seqlen, D] if `position` is specified, else [1, seqlen, D] - """ - if position is None: - assert seq_length is not None - # [1, seqlen] - position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) - else: - assert position.ndim == 2, position.shape - - num_timescales = self.embedding_dims // 2 - log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale) - ) / max(num_timescales - 1, 1) - inv_timescales = self.min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) - # Padding to ensure correct embedding dimension - signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) - return signal diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index c928cf9aca8e..c6d8f932730b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -18,7 +18,6 @@ from typing import List import numpy as np -import torch from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( From eecdf4aee0cbfcad70746e1a0bec4881d8a39785 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 20:32:02 +0100 Subject: [PATCH 025/242] add intermediate_size and initialize with config --- .../models/timesfm/configuration_timesfm.py | 10 ++- .../models/timesfm/modeling_timesfm.py | 79 ++++--------------- tests/models/timesfm/test_modeling_timesfm.py | 3 + 3 files changed, 26 insertions(+), 66 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 0ff463ba270d..26cf828f0da9 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -47,6 +47,8 @@ class TimesFMConfig(PretrainedConfig): Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * head_dim`. @@ -69,12 +71,14 @@ class TimesFMConfig(PretrainedConfig): initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. """ model_type = "timesfm" keys_to_ignore_at_inference = [] attribute_map = { - "hidden_size": "hidden_size", + "hidden_size": "model_dim", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } @@ -88,6 +92,7 @@ def __init__( freq_size: int = 3, num_layers: int = 20, model_dim: int = 1280, + intermediate_size: int = 1280, head_dim: int = 80, num_heads: int = 16, dropout_rate: float = 0.1, @@ -98,6 +103,7 @@ def __init__( use_positional_embedding: bool = True, batch_size: int = 32, initializer_factor: float = 1.0, + attention_dropout: float = 0.0, **kwargs, ): self.patch_len = patch_len @@ -107,6 +113,7 @@ def __init__( self.pad_val = pad_val self.freq_size = freq_size self.model_dim = model_dim + self.intermediate_size = intermediate_size self.head_dim = head_dim self.num_layers = num_layers self.num_heads = num_heads @@ -116,6 +123,7 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.batch_size = batch_size self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 9e0959b8e12f..7abfbb9f5ee3 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -181,23 +181,16 @@ def forward(self, seq_length=None, position=None): class TimesFMAttention(nn.Module): """Implements the attention used in TimesFM.""" - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - ): + def __init__(self, config: TimesFMConfig): super().__init__() - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads + self.num_heads = config.num_heads + self.num_kv_heads = config.num_heads assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.hidden_size = hidden_size - self.head_dim = head_dim + self.hidden_size = config.model_dim + self.head_dim = config.head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -274,33 +267,17 @@ def forward( # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) output = self.o_proj(output) - return scores, output + return output, scores class TimesFMDecoderLayer(nn.Module): """Transformer layer.""" - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - rms_norm_eps: float = 1e-6, - ): + def __init__(self, config: TimesFMConfig): super().__init__() - self.self_attn = TimesFMAttention( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - ) - self.mlp = TimesFMTransformerMLP( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - ) - self.input_layernorm = TimesFMRMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = TimesFMAttention(config) + self.mlp = TimesFMTransformerMLP(config.model_dim, config.intermediate_size) + self.input_layernorm = TimesFMRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( self, @@ -313,7 +290,7 @@ def forward( # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - scores, hidden_states = self.self_attn( + hidden_states, scores = self.self_attn( hidden_states=hidden_states, mask=mask, kv_write_indices=kv_write_indices, @@ -330,30 +307,10 @@ def forward( class TimesFMStackedDecoder(nn.Module): """Stacked transformer layer.""" - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - num_layers: int, - rms_norm_eps: float = 1e-6, - ): + def __init__(self, config: TimesFMConfig): super().__init__() - self.layers = nn.ModuleList() - for _ in range(num_layers): - self.layers.append( - TimesFMDecoderLayer( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - rms_norm_eps=rms_norm_eps, - ) - ) + self.layers = nn.ModuleList([TimesFMDecoderLayer(config) for _ in range(config.num_layers)]) def forward( self, @@ -602,15 +559,7 @@ def __init__(self, config: TimesFMConfig): output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.model_dim, ) - self.stacked_transformer = TimesFMStackedDecoder( - hidden_size=self.config.model_dim, - intermediate_size=self.config.model_dim, - num_heads=self.config.num_heads, - num_kv_heads=self.config.num_heads, - head_dim=self.config.head_dim, - num_layers=self.config.num_layers, - rms_norm_eps=self.config.rms_norm_eps, - ) + self.stacked_transformer = TimesFMStackedDecoder(config=config) if self.config.use_positional_embedding: self.position_emb = TimesFMPositionalEmbedding( embedding_dims=self.config.model_dim, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index c6d8f932730b..394bea5e1b3b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -56,6 +56,7 @@ def __init__( freq_size: int = 3, num_layers: int = 4, model_dim: int = 128, + intermediate_size: int = 1280, head_dim: int = 16, num_heads: int = 4, dropout_rate: float = 0.1, @@ -76,6 +77,7 @@ def __init__( self.pad_val = pad_val self.freq_size = freq_size self.model_dim = model_dim + self.intermediate_size = intermediate_size self.head_dim = head_dim self.num_hidden_layers = num_layers self.num_attention_heads = num_heads @@ -103,6 +105,7 @@ def get_config(self): pad_val=self.pad_val, freq_size=self.freq_size, model_dim=self.model_dim, + intermediate_size=self.intermediate_size, head_dim=self.head_dim, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, From f5399657493114fcac574eaa0488278b43d82c47 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 20:37:58 +0100 Subject: [PATCH 026/242] initial documentation --- docs/source/en/model_doc/timesfm.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 9acc824f9e0f..1ae971603246 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -18,19 +18,19 @@ rendered properly in your Markdown viewer. ## Overview -The TimesFM model was proposed in []() by . - +TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder only model that uses non-overlapping patches of time-series data as input and outputs some output patch length prediction in an autoregressive fashion. + The abstract from the paper is the following: -** +*Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.* Tips: This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +The original code can be found [here](https://github.com/google-research/timesfm). ## TimesFMConfig From f95d6eeba291342e66dae66fd0b09505414bd4a9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 20:49:06 +0100 Subject: [PATCH 027/242] rename mask to attention_mask --- src/transformers/models/timesfm/modeling_timesfm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 7abfbb9f5ee3..ca3b41caf6e5 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -215,7 +215,7 @@ def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - mask: torch.Tensor, + attention_mask: torch.Tensor | None = None, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: @@ -257,7 +257,10 @@ def forward( # [batch_size, n_local_heads, input_len, max_seq_len] scores = torch.matmul(q, k.transpose(2, 3)) - scores = scores + mask + + if attention_mask is not None: + scores = scores + attention_mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) # [batch_size, n_local_heads, input_len, head_dim] @@ -282,7 +285,7 @@ def __init__(self, config: TimesFMConfig): def forward( self, hidden_states: torch.Tensor, - mask: torch.Tensor, + attention_mask: torch.Tensor, paddings: torch.Tensor, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, @@ -292,7 +295,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) hidden_states, scores = self.self_attn( hidden_states=hidden_states, - mask=mask, + attention_mask=attention_mask, kv_write_indices=kv_write_indices, kv_cache=kv_cache, ) @@ -332,7 +335,7 @@ def forward( kv_cache = kv_caches[i] if kv_caches is not None else None scores, hidden_states = layer( hidden_states=hidden_states, - mask=mask, + attention_mask=mask, paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, From d475529b44139ba6e6c1cf53ebd49c011a0a5fbd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 21:25:02 +0100 Subject: [PATCH 028/242] smaller tests --- .../models/timesfm/modeling_timesfm.py | 27 +++++++++++++------ tests/models/timesfm/test_modeling_timesfm.py | 10 +++---- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ca3b41caf6e5..a28f9dea50e8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -218,7 +218,8 @@ def forward( attention_mask: torch.Tensor | None = None, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -270,6 +271,10 @@ def forward( # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) output = self.o_proj(output) + + if output_attentions: + scores = None + return output, scores @@ -323,7 +328,7 @@ def forward( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> torch.Tensor: + ) -> BaseModelOutput: padding_mask = timesfm_convert_paddings_to_mask(paddings, hidden_states.dtype) atten_mask = timesfm_causal_mask(hidden_states) mask = timesfm_merge_masks(padding_mask, atten_mask) @@ -345,7 +350,11 @@ def forward( if output_hidden_states: all_hidden_states.append(hidden_states) - return hidden_states, all_attentions, all_hidden_states + return BaseModelOutput( + last_hidden_state=hidden_states, + attentions=all_attentions, + hidden_states=all_hidden_states, + ) # Move utility functions here @@ -664,7 +673,7 @@ def forward( freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: num_outputs = len(self.config.quantiles) + 1 model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, @@ -673,17 +682,19 @@ def forward( f_emb = self.freq_emb(freq) # B x 1 x D model_input += f_emb - model_output, all_attentions, all_hidden_states = self.stacked_transformer( + transformer_output = self.stacked_transformer( model_input, patched_padding, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if output_hidden_states: - all_hidden_states = [model_input] + all_hidden_states + all_hidden_states = [model_input] + transformer_output.hidden_states + else: + all_hidden_states = None - output_ts = self._postprocess_output(model_output, num_outputs, stats) - return output_ts, all_attentions, all_hidden_states + output_ts = self._postprocess_output(transformer_output.last_hidden_state, num_outputs, stats) + return output_ts, transformer_output.attentions, all_hidden_states def decode( self, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 394bea5e1b3b..2b303d12d1ac 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -54,11 +54,11 @@ def __init__( context_len: int = 512, horizon_len: int = 128, freq_size: int = 3, - num_layers: int = 4, - model_dim: int = 128, - intermediate_size: int = 1280, - head_dim: int = 16, - num_heads: int = 4, + num_layers: int = 1, + model_dim: int = 16, + intermediate_size: int = 32, + head_dim: int = 2, + num_heads: int = 2, dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, From f7c1fe0074f95c260ab22738986f14ccd773dfdf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 21:36:03 +0100 Subject: [PATCH 029/242] fixup --- src/transformers/models/auto/configuration_auto.py | 4 ++-- src/transformers/models/auto/modeling_auto.py | 4 ++-- tests/models/timesfm/test_modeling_timesfm.py | 5 +---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2bc9a5057530..bfadb29ac74f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -267,10 +267,10 @@ ("swinv2", "Swinv2Config"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), - ("timesfm", "TimesFMConfig"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), + ("timesfm", "TimesFMConfig"), ("timesformer", "TimesformerConfig"), ("timm_backbone", "TimmBackboneConfig"), ("trajectory_transformer", "TrajectoryTransformerConfig"), @@ -585,12 +585,12 @@ ("swinv2", "Swin Transformer V2"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), - ("timesfm", "TimesFM"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), ("tapex", "TAPEX"), ("time_series_transformer", "Time Series Transformer"), + ("timesfm", "TimesFM"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), ("trajectory_transformer", "Trajectory Transformer"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6237eea9e727..d71cb068f2e9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -246,10 +246,10 @@ ("swinv2", "Swinv2Model"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), - ("timesfm", "TimesFMModel"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), + ("timesfm", "TimesFMModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("trajectory_transformer", "TrajectoryTransformerModel"), @@ -355,8 +355,8 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMModel"), ("tapas", "TapasForMaskedLM"), + ("timesfm", "TimesFMModel"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 2b303d12d1ac..71700a1e8102 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -40,7 +40,6 @@ if is_torch_available(): - from transformers import ( TimesFMModel, ) @@ -151,9 +150,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class TimesFMModelTest( - ModelTesterMixin, unittest.TestCase -): +class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (TimesFMModel,) if is_torch_available() else () all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () all_parallelizable_model_classes = () From 5a808be03b33dd9fae4149b339661d172a3f1a53 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 21:37:42 +0100 Subject: [PATCH 030/242] fix copies --- docs/source/en/index.md | 1 + .../models/timesfm/configuration_timesfm.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8a9ccf45b69c..783115f3d442 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -316,6 +316,7 @@ Flax), PyTorch, and/or TensorFlow. | [TAPAS](model_doc/tapas) | ✅ | ✅ | ❌ | | [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ | | [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ | +| [TimesFM](model_doc/timesfm) | ✅ | ❌ | ❌ | | [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ | | [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ | | [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ | diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 26cf828f0da9..aa6a64e69bce 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -47,7 +47,7 @@ class TimesFMConfig(PretrainedConfig): Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. - intermediate_size (`int`, *optional*, defaults to 11008): + intermediate_size (`int`, *optional*, defaults to 1280): Dimension of the MLP representations. head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1238f058783c..df8969090e5d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8945,6 +8945,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class TimesFMModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TimesFMPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TimesformerForVideoClassification(metaclass=DummyObject): _backends = ["torch"] From 6dbbd803abb4caa067806534cbb2d1bab81432ae Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 12:54:43 +0100 Subject: [PATCH 031/242] move to time series section --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0b802e7e44c0..061e4fa1a4d8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -586,8 +586,6 @@ title: T5v1.1 - local: model_doc/tapex title: TAPEX - - local: model_doc/timesfm - title: TimesFM - local: model_doc/transfo-xl title: Transformer XL - local: model_doc/ul2 @@ -950,6 +948,8 @@ title: PatchTSMixer - local: model_doc/patchtst title: PatchTST + - local: model_doc/timesfm + title: TimesFM - local: model_doc/time_series_transformer title: Time Series Transformer title: Time series models From bfa9302b7a9ffb1d97737f3b047cbc41aa91608b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 12:56:24 +0100 Subject: [PATCH 032/242] sort docs --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 061e4fa1a4d8..c3652b1a9b94 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -948,10 +948,10 @@ title: PatchTSMixer - local: model_doc/patchtst title: PatchTST - - local: model_doc/timesfm - title: TimesFM - local: model_doc/time_series_transformer title: Time Series Transformer + - local: model_doc/timesfm + title: TimesFM title: Time series models - isExpanded: false sections: From eb5807e5f99e78fa986403d22afec40c098b1fd3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 12:58:46 +0100 Subject: [PATCH 033/242] isort fix --- src/transformers/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7991c23b7ed1..11fcfa4483da 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -767,13 +767,13 @@ "models.swinv2": ["Swinv2Config"], "models.switch_transformers": ["SwitchTransformersConfig"], "models.t5": ["T5Config"], - "models.timesfm": ["TimesFMConfig"], "models.table_transformer": ["TableTransformerConfig"], "models.tapas": [ "TapasConfig", "TapasTokenizer", ], "models.time_series_transformer": ["TimeSeriesTransformerConfig"], + "models.timesfm": ["TimesFMConfig"], "models.timesformer": ["TimesformerConfig"], "models.timm_backbone": ["TimmBackboneConfig"], "models.trocr": [ @@ -3475,12 +3475,6 @@ "load_tf_weights_in_t5", ] ) - _import_structure["models.timesfm"].extend( - [ - "TimesFMModel", - "TimesFMPreTrainedModel", - ] - ) _import_structure["models.table_transformer"].extend( [ "TableTransformerForObjectDetection", @@ -3505,6 +3499,12 @@ "TimeSeriesTransformerPreTrainedModel", ] ) + _import_structure["models.timesfm"].extend( + [ + "TimesFMModel", + "TimesFMPreTrainedModel", + ] + ) _import_structure["models.timesformer"].extend( [ "TimesformerForVideoClassification", From be32365c9378608e2cb620bc4cc1f382d1961f83 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 13:06:34 +0100 Subject: [PATCH 034/242] batch_size is not a configuration --- .../models/timesfm/configuration_timesfm.py | 4 --- .../models/timesfm/modeling_timesfm.py | 29 ++----------------- tests/models/timesfm/test_modeling_timesfm.py | 3 -- 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index aa6a64e69bce..45691cd2f46a 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -66,8 +66,6 @@ class TimesFMConfig(PretrainedConfig): The value used to pad the predictions. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. - batch_size (`int`, *optional*, defaults to 32): - The batch size. initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). @@ -101,7 +99,6 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - batch_size: int = 32, initializer_factor: float = 1.0, attention_dropout: float = 0.0, **kwargs, @@ -121,7 +118,6 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.batch_size = batch_size self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index a28f9dea50e8..a2b2d6616d4c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -788,14 +788,6 @@ def __init__(self, config: TimesFMConfig): self.context_len = config.context_len self.horizon_len = config.horizon_len - self.input_patch_len = config.patch_len - self.output_patch_len = config.horizon_len - self.num_layers = config.num_layers - self.model_dims = config.model_dim - self.quantiles = config.quantiles - self.num_heads = config.num_heads - self.batch_size = config.batch_size - self._horizon_start = self.context_len - self.input_patch_len # Initialize weights and apply final processing self.post_init() @@ -905,24 +897,9 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - input_ts_in = torch.from_numpy( - np.array( - input_ts, - dtype=np.float32, - ) - ) - input_padding_in = torch.from_numpy( - np.array( - input_padding, - dtype=np.float32, - ) - ) - inp_freq_in = torch.from_numpy( - np.array( - inp_freq, - dtype=np.int32, - ) - ).long() + input_ts_in = torch.from_numpy(np.array(input_ts, dtype=np.float32)) + input_padding_in = torch.from_numpy(np.array(input_padding, dtype=np.float32)) + inp_freq_in = torch.from_numpy(np.array(inp_freq, dtype=np.int32)).long() mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decoder.decode( input_ts=input_ts_in, paddings=input_padding_in, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 71700a1e8102..8d43a5b1498b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -64,7 +64,6 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - batch_size: int = 32, initializer_factor: float = 0.0, is_training: bool = False, ): @@ -84,7 +83,6 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.batch_size = batch_size self.initializer_factor = initializer_factor self.is_training = is_training @@ -112,7 +110,6 @@ def get_config(self): tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, use_positional_embedding=self.use_positional_embedding, - batch_size=self.batch_size, initializer_factor=self.initializer_factor, ) From c4a361088fb075906d7c9f13da1b4876e9efbe53 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 30 Nov 2024 18:28:31 +0100 Subject: [PATCH 035/242] rename to TimesFMModelForPrediction --- docs/source/en/model_doc/timesfm.md | 28 +-- src/transformers/__init__.py | 6 +- src/transformers/models/auto/modeling_auto.py | 4 +- src/transformers/models/timesfm/__init__.py | 8 +- .../models/timesfm/configuration_timesfm.py | 10 +- .../models/timesfm/modeling_timesfm.py | 216 ++++++++++-------- src/transformers/utils/dummy_pt_objects.py | 9 +- tests/models/timesfm/test_modeling_timesfm.py | 16 +- 8 files changed, 149 insertions(+), 148 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 1ae971603246..76cf1f8afef2 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -37,34 +37,14 @@ The original code can be found [here](https://github.com/google-research/timesfm [[autodoc]] TimesFMConfig -## TimesFMModel +## TimesFMDecoder -[[autodoc]] TimesFMModel +[[autodoc]] TimesFMDecoder - forward -## TimesFMForConditionalGeneration +## TimesFMModelForPrediction -[[autodoc]] TimesFMForConditionalGeneration - - forward - -## TimesFMEncoderModel - -[[autodoc]] TimesFMEncoderModel - - forward - -## TimesFMForSequenceClassification - -[[autodoc]] TimesFMForSequenceClassification - - forward - -## TimesFMForTokenClassification - -[[autodoc]] TimesFMForTokenClassification - - forward - -## TimesFMForQuestionAnswering - -[[autodoc]] TimesFMForQuestionAnswering +[[autodoc]] TimesFMModelForPrediction - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 11fcfa4483da..6fe3a85cdb3c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3501,7 +3501,8 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMModel", + "TimesFMModelForPrediction", + "TimesFMDecoder", "TimesFMPreTrainedModel", ] ) @@ -7993,7 +7994,8 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFMModel, + TimesFMDecoder, + TimesFMModelForPrediction, TimesFMPreTrainedModel, ) from .models.timesformer import ( diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d71cb068f2e9..f01a84e4e7a9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -249,7 +249,7 @@ ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), - ("timesfm", "TimesFMModel"), + ("timesfm", "TimesFMModelForPrediction"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("trajectory_transformer", "TrajectoryTransformerModel"), @@ -356,7 +356,7 @@ ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), - ("timesfm", "TimesFMModel"), + ("timesfm", "TimesFMModelForPrediction"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 6592a5b1620e..51028a860782 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -30,7 +30,8 @@ pass else: _import_structure["modeling_timesfm"] = [ - "TimesFMModel", + "TimesFMModelForPrediction", + "TimesFMDecoder", "TimesFMPreTrainedModel", ] @@ -43,10 +44,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_timesfm import ( - TimesFMModel, - TimesFMPreTrainedModel, - ) + from .modeling_timesfm import TimesFMDecoder, TimesFMModelForPrediction, TimesFMPreTrainedModel else: import sys diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 45691cd2f46a..eb135c03c968 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -26,7 +26,7 @@ class TimesFMConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`TimesFMModel`] or a [`TFTimesFMModel`]. It is used to + This is the configuration class to store the configuration of a [`TimesFMModelForPrediction`] or a [`TFTimesFMDecoder`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. @@ -54,8 +54,6 @@ class TimesFMConfig(PretrainedConfig): be defined as `num_heads * head_dim`. num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. - dropout_rate (`float`, *optional*, defaults to 0.1): - The ratio for all dropout layers. tolerance (`float`, *optional*, defaults to 1e-06): The tolerance for the quantile loss. rms_norm_eps (`float`, *optional*, defaults to 1e-06): @@ -69,8 +67,6 @@ class TimesFMConfig(PretrainedConfig): initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. """ model_type = "timesfm" @@ -93,14 +89,12 @@ def __init__( intermediate_size: int = 1280, head_dim: int = 80, num_heads: int = 16, - dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, initializer_factor: float = 1.0, - attention_dropout: float = 0.0, **kwargs, ): self.patch_len = patch_len @@ -114,12 +108,10 @@ def __init__( self.head_dim = head_dim self.num_layers = num_layers self.num_heads = num_heads - self.dropout_rate = dropout_rate self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding self.initializer_factor = initializer_factor - self.attention_dropout = attention_dropout super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index a2b2d6616d4c..f2134e90e463 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -37,9 +37,15 @@ @dataclass -class TimesFMOutput(BaseModelOutput): - mean_predictions: np.ndarray = None - full_predictions: np.ndarray = None +class TimesFMDecoderOutput(BaseModelOutput): + loc: np.ndarray | None = None + scale: np.ndarray | None = None + + +@dataclass +class TimesFMOutputForPrediction(BaseModelOutput): + mean_predictions: np.ndarray | None = None + full_predictions: np.ndarray | None = None class TimesFMTransformerMLP(nn.Module): @@ -553,8 +559,8 @@ def _init_weights(self, module): pass -class PatchedTimeSeriesDecoder(TimesFMPreTrainedModel): - """Patched time-series decoder.""" +class TimesFMDecoder(TimesFMPreTrainedModel): + """Patched time-series decoder without any specific output layer.""" def __init__(self, config: TimesFMConfig): super().__init__(config) @@ -566,11 +572,6 @@ def __init__(self, config: TimesFMConfig): hidden_dims=config.model_dim, ) self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.model_dim, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.model_dim, - ) self.stacked_transformer = TimesFMStackedDecoder(config=config) if self.config.use_positional_embedding: self.position_emb = TimesFMPositionalEmbedding( @@ -600,11 +601,6 @@ def _forward_transform( ) return outputs, (mu, sigma) - def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """Output is of shape [B, N, P, Q].""" - mu, sigma = stats - return outputs * sigma[:, None, None, None] + mu[:, None, None, None] - def _preprocess_input( self, input_ts: torch.Tensor, @@ -649,23 +645,6 @@ def _preprocess_input( return model_input, patched_padding, stats, patched_inputs - def _postprocess_output( - self, - model_output: torch.Tensor, - num_outputs: int, - stats: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - """Postprocess output of stacked transformer.""" - - # B x N x (H.Q) - output_ts = self.horizon_ff_layer(model_output) - - # Reshape using view - b, n, _ = output_ts.shape - output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) - - return self._reverse_transform(output_ts, stats) - def forward( self, input_ts: torch.Tensor, @@ -673,8 +652,7 @@ def forward( freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - num_outputs = len(self.config.quantiles) + 1 + ) -> TimesFMDecoderOutput: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -693,8 +671,96 @@ def forward( else: all_hidden_states = None - output_ts = self._postprocess_output(transformer_output.last_hidden_state, num_outputs, stats) - return output_ts, transformer_output.attentions, all_hidden_states + return TimesFMDecoderOutput( + last_hidden_state=transformer_output.last_hidden_state, + hidden_states=all_hidden_states, + attentions=transformer_output.attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + + +class TimesFMModelForPrediction(TimesFMPreTrainedModel): + def __init__(self, config: TimesFMConfig): + super().__init__(config) + + self.config = config + self.context_len = config.context_len + self.horizon_len = config.horizon_len + + self.decoder = TimesFMDecoder(config) + + # quantile and mean output + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.model_dim, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.model_dim, + ) + + # Initialize weights and apply final processing + self.post_init() + + def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d Tensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + input_ts, input_padding, inp_freq = [], [], [] + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) + padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + ) + + def _postprocess_output( + self, + model_output: torch.Tensor, + stats: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, len(self.config.quantiles) + 1) + + return self._reverse_transform(output_ts, stats) + + def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] def decode( self, @@ -732,6 +798,7 @@ def decode( final_out = input_ts context_len = final_out.shape[1] full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: raise ValueError( "Length of paddings must match length of input + horizon_len:" @@ -744,13 +811,18 @@ def decode( current_padding = paddings[:, 0 : final_out.shape[1]] input_ts = final_out[:, -max_len:] input_padding = current_padding[:, -max_len:] - fprop_outputs, all_attentions, all_hidden_states = self.forward( + decoder_output = self.decoder( input_ts, input_padding, freq, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) + fprop_outputs = self._postprocess_output( + decoder_output.last_hidden_state, + (decoder_output.loc, decoder_output.scale), + ) + if return_forecast_on_context and step_index == 0: # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. @@ -775,62 +847,12 @@ def decode( # `full_outputs` indexing starts at the forecast horizon. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - return full_outputs[:, :, 0], full_outputs, fprop_outputs, all_attentions, all_hidden_states - - -class TimesFMModel(TimesFMPreTrainedModel): - def __init__(self, config: TimesFMConfig): - super().__init__(config) - - self.config = config - - self.decoder = PatchedTimeSeriesDecoder(config) - - self.context_len = config.context_len - self.horizon_len = config.horizon_len - - # Initialize weights and apply final processing - self.post_init() - - def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: - """Formats and pads raw inputs to feed into the model. - - This function both pads each time series to match the context length, and - pads the inputs to meet the SPMD shape requirement. - - Args: - inputs: A list of 1d JTensors. Each JTensor is the context time series of - a single forecast task. - freq: list of frequencies - - Returns: - A tuple of: - - the padded input time series to meet the model required context. - - the padding indicator. - - the number of padded examples for SPMD so that each core has the same - number (a multiple of `batch_size`) of examples. - """ - input_ts, input_padding, inp_freq = [], [], [] - - for i, ts in enumerate(inputs): - input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) - if input_len < self.context_len: - num_front_pad = self.context_len - input_len - ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) - padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) - elif input_len > self.context_len: - ts = ts[-self.context_len :] - padding = padding[-(self.context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) - inp_freq.append(freq[i]) - return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), + full_outputs[:, :, 0], + full_outputs, + decoder_output.last_hidden_state, + decoder_output.attentions, + decoder_output.hidden_states, ) def forward( @@ -844,12 +866,12 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, - ) -> tuple[np.ndarray, np.ndarray]: + ) -> TimesFMOutputForPrediction: """Forecasts on a list of time series. Args: inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. + should be in a format convertible to Tensor. freq: frequency of each context time series. 0 for high frequency (default), 1 for medium, and 2 for low. Notice this is different from the `freq` required by `forecast_on_df`. @@ -862,7 +884,7 @@ def forward( have non-negative values. Returns: - A tuple for JTensors: + A tuple for Tensors: - the mean forecast of size (# inputs, # forecast horizon), - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). @@ -900,7 +922,7 @@ def forward( input_ts_in = torch.from_numpy(np.array(input_ts, dtype=np.float32)) input_padding_in = torch.from_numpy(np.array(input_padding, dtype=np.float32)) inp_freq_in = torch.from_numpy(np.array(inp_freq, dtype=np.int32)).long() - mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decoder.decode( + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( input_ts=input_ts_in, paddings=input_padding_in, freq=inp_freq_in, @@ -918,7 +940,7 @@ def forward( full_outputs = torch.maximum(full_outputs, 0.0) if return_dict: - return TimesFMOutput( + return TimesFMOutputForPrediction( last_hidden_state=last_hidden_state, attentions=all_attentions if output_attentions else None, hidden_states=all_hidden_states if output_hidden_states else None, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index df8969090e5d..88a1ad22f475 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8945,7 +8945,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMModel(metaclass=DummyObject): +class TimesFMDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TimesFMModelForPrediction(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 8d43a5b1498b..ec7550b043f3 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -40,9 +40,7 @@ if is_torch_available(): - from transformers import ( - TimesFMModel, - ) + from transformers import TimesFMModelForPrediction class TimesFMModelTester: @@ -66,6 +64,7 @@ def __init__( use_positional_embedding: bool = True, initializer_factor: float = 0.0, is_training: bool = False, + batch_size: int = 3, ): self.parent = parent self.patch_len = patch_len @@ -85,6 +84,7 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.initializer_factor = initializer_factor self.is_training = is_training + self.batch_size = batch_size # The size of test input self.seq_length = context_len // patch_len @@ -148,8 +148,8 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (TimesFMModel,) if is_torch_available() else () - all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () + all_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () + all_generative_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () all_parallelizable_model_classes = () fx_compatible = False test_pruning = False @@ -164,7 +164,7 @@ def setUp(self): def test_create_and_run_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = TimesFMModel(config) + model = TimesFMModelForPrediction(config) model.to(torch_device) model.eval() results = model(**inputs_dict) @@ -180,7 +180,7 @@ def test_headmasking(self): # the main input name is `inputs` def test_model_main_input_name(self): - model_signature = inspect.signature(getattr(TimesFMModel, "forward")) + model_signature = inspect.signature(getattr(TimesFMModelForPrediction, "forward")) # The main input is the name of the argument after `self` observed_main_input_name = list(model_signature.parameters.keys())[1] - self.assertEqual(TimesFMModel.main_input_name, observed_main_input_name) + self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name) From 56a56062944d314d2df309e27e842f33130df3cc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 4 Dec 2024 20:41:04 +0100 Subject: [PATCH 036/242] initial script --- .../convert_timesfm_orignal_to_pytorch.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py new file mode 100644 index 000000000000..bf186c06574c --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py @@ -0,0 +1,76 @@ +import argparse +import os +import shutil + +import timesfm + +from transformers import TimesFMConfig, TimesFMModelForPrediction + + +""" +Sample usage: + +``` +python src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py \ + --output_dir /output/path +``` +""" + + +def write_model(model_path, safe_serialization=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cpu", + per_core_batch_size=32, + horizon_len=128, + ), + checkpoint=timesfm.TimesFmCheckpoint( + huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + ) + + timesfm_config = TimesFMConfig( + patch_len=tfm.hparams.input_patch_len, + context_len=tfm.hparams.context_len, + horizon_len=tfm.hparams.horizon_len, + num_layers=tfm.hparams.num_layers, + model_dim=tfm.hparams.model_dims, + intermediate_size=tfm.hparams.model_dims, + head_dim=tfm.hparams.model_dims//tfm.hparams.num_heads, + num_heads=tfm.hparams.num_heads, + ) + timesfm_config.save_pretrained(tmp_model_path) + timesfm_model = TimesFMModelForPrediction(timesfm_config) + + # copy the weights from the original model to the new model making + import pdb; pdb.set_trace() + orignal_model = tfm._model + + + timesfm_model.load_state_dict(tfm.state_dict()) + timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + args = parser.parse_args() + write_model( + model_path=args.output_dir, + safe_serialization=args.safe_serialization, + ) + + check_outputs(args.output_dir) + + +if __name__ == "__main__": + main() From 01756aef11d450e3ed76d7f207cdc65794754ff0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:18:12 +0100 Subject: [PATCH 037/242] add check_outputs --- .../convert_timesfm_orignal_to_pytorch.py | 173 ++++++++++++++++-- 1 file changed, 162 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py index bf186c06574c..ac638e9efe66 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py @@ -2,7 +2,9 @@ import os import shutil +import numpy as np import timesfm +import torch from transformers import TimesFMConfig, TimesFMModelForPrediction @@ -23,13 +25,12 @@ def write_model(model_path, safe_serialization=True): os.makedirs(tmp_model_path, exist_ok=True) tfm = timesfm.TimesFm( - hparams=timesfm.TimesFmHparams( - backend="cpu", - per_core_batch_size=32, - horizon_len=128, - ), - checkpoint=timesfm.TimesFmCheckpoint( - huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + hparams=timesfm.TimesFmHparams( + backend="cpu", + per_core_batch_size=32, + horizon_len=128, + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), ) timesfm_config = TimesFMConfig( @@ -39,22 +40,172 @@ def write_model(model_path, safe_serialization=True): num_layers=tfm.hparams.num_layers, model_dim=tfm.hparams.model_dims, intermediate_size=tfm.hparams.model_dims, - head_dim=tfm.hparams.model_dims//tfm.hparams.num_heads, + head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads, num_heads=tfm.hparams.num_heads, ) timesfm_config.save_pretrained(tmp_model_path) timesfm_model = TimesFMModelForPrediction(timesfm_config) # copy the weights from the original model to the new model making - import pdb; pdb.set_trace() - orignal_model = tfm._model + original_model = tfm._model + + # Map decoder input_ff_layer + timesfm_model.decoder.input_ff_layer.hidden_layer[0].weight.data = original_model.input_ff_layer.hidden_layer[ + 0 + ].weight.data + timesfm_model.decoder.input_ff_layer.hidden_layer[0].bias.data = original_model.input_ff_layer.hidden_layer[ + 0 + ].bias.data + timesfm_model.decoder.input_ff_layer.output_layer.weight.data = ( + original_model.input_ff_layer.output_layer.weight.data + ) + timesfm_model.decoder.input_ff_layer.output_layer.bias.data = original_model.input_ff_layer.output_layer.bias.data + timesfm_model.decoder.input_ff_layer.residual_layer.weight.data = ( + original_model.input_ff_layer.residual_layer.weight.data + ) + timesfm_model.decoder.input_ff_layer.residual_layer.bias.data = ( + original_model.input_ff_layer.residual_layer.bias.data + ) + # Map freq embedding + timesfm_model.decoder.freq_emb.weight.data = original_model.freq_emb.weight.data + + # Map horizon_ff_layer + timesfm_model.horizon_ff_layer.hidden_layer[0].weight.data = original_model.horizon_ff_layer.hidden_layer[ + 0 + ].weight.data + timesfm_model.horizon_ff_layer.hidden_layer[0].bias.data = original_model.horizon_ff_layer.hidden_layer[ + 0 + ].bias.data + timesfm_model.horizon_ff_layer.output_layer.weight.data = original_model.horizon_ff_layer.output_layer.weight.data + timesfm_model.horizon_ff_layer.output_layer.bias.data = original_model.horizon_ff_layer.output_layer.bias.data + timesfm_model.horizon_ff_layer.residual_layer.weight.data = ( + original_model.horizon_ff_layer.residual_layer.weight.data + ) + timesfm_model.horizon_ff_layer.residual_layer.bias.data = original_model.horizon_ff_layer.residual_layer.bias.data + + # Map transformer layers + for i in range(len(timesfm_model.decoder.stacked_transformer.layers)): + # Map attention layers + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.qkv_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.qkv_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.o_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.o_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.scaling.data = original_model.stacked_transformer.layers[i].self_attn.scaling.data + + # Map MLP layers + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.gate_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.gate_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.down_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.down_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.down_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.down_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.layer_norm.weight.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.layer_norm.bias.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.bias.data + + # Map layer norms + timesfm_model.decoder.stacked_transformer.layers[ + i + ].input_layernorm.weight.data = original_model.stacked_transformer.layers[i].input_layernorm.weight.data - timesfm_model.load_state_dict(tfm.state_dict()) timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path) +def check_outputs(model_path): + """Compares outputs between original and converted models.""" + print("\nChecking model outputs...") + + # Load original model + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cpu", + per_core_batch_size=32, + horizon_len=128, + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + ) + + # Load converted model + converted_model = TimesFMModelForPrediction.from_pretrained(model_path) + converted_model.eval() # Set to evaluation mode + + # Create test inputs + forecast_input = [ + np.sin(np.linspace(0, 20, 100)), + np.sin(np.linspace(0, 20, 200)), + np.sin(np.linspace(0, 20, 400)), + ] + frequency_input = [0, 1, 2] + + # Get predictions from original model + point_forecast_orig, quantile_forecast_orig = tfm.forecast( + forecast_input, + freq=frequency_input, + ) + + # Get predictions from converted model + with torch.no_grad(): + outputs = converted_model(inputs=forecast_input, freq=frequency_input, return_dict=True) + point_forecast_conv = outputs.mean_predictions.numpy() + quantile_forecast_conv = outputs.full_predictions.numpy() + + # Compare outputs + point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv) + quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv) + + max_point_diff = point_forecast_diff.max() + mean_point_diff = point_forecast_diff.mean() + max_quantile_diff = quantile_forecast_diff.max() + mean_quantile_diff = quantile_forecast_diff.mean() + + print("\nOutput comparison:") + print(f"Point forecast - Max difference: {max_point_diff:.6f}") + print(f"Point forecast - Mean difference: {mean_point_diff:.6f}") + print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}") + print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}") + + # Define acceptable thresholds + POINT_THRESHOLD = 1e-5 + QUANTILE_THRESHOLD = 1e-5 + + if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD: + raise ValueError( + f"Output mismatch detected!\n" + f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n" + f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})" + ) + + print("\n✓ All outputs match within acceptable tolerance!") + + # Optional: Print shapes for verification + print("\nOutput shapes:") + print(f"Original point forecast: {point_forecast_orig.shape}") + print(f"Converted point forecast: {point_forecast_conv.shape}") + print(f"Original quantile forecast: {quantile_forecast_orig.shape}") + print(f"Converted quantile forecast: {quantile_forecast_conv.shape}") + + def main(): parser = argparse.ArgumentParser() parser.add_argument( From 942c23cdf5eab49368a329341c298157d9bf04bc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:21:47 +0100 Subject: [PATCH 038/242] remove dropout_rate --- tests/models/timesfm/test_modeling_timesfm.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index ec7550b043f3..a6cdaf17846e 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -56,7 +56,6 @@ def __init__( intermediate_size: int = 32, head_dim: int = 2, num_heads: int = 2, - dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], @@ -78,7 +77,6 @@ def __init__( self.head_dim = head_dim self.num_hidden_layers = num_layers self.num_attention_heads = num_heads - self.dropout_rate = dropout_rate self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding @@ -106,7 +104,6 @@ def get_config(self): head_dim=self.head_dim, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, - dropout_rate=self.dropout_rate, tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, use_positional_embedding=self.use_positional_embedding, @@ -126,18 +123,10 @@ def prepare_config_and_inputs(self): config = self.get_config() - return ( - config, - forecast_input, - frequency_input, - ) + return (config, forecast_input, frequency_input) def prepare_config_and_inputs_for_common(self): - ( - config, - forecast_input, - frequency_input, - ) = self.prepare_config_and_inputs() + (config, forecast_input, frequency_input) = self.prepare_config_and_inputs() inputs_dict = { "inputs": forecast_input, From e7650bd6159adaf8a91b364cd267d08a0ec1c35a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:53:09 +0100 Subject: [PATCH 039/242] works with torch.Tensor inputs --- .../convert_timesfm_orignal_to_pytorch.py | 10 +++- .../models/timesfm/modeling_timesfm.py | 56 +++++++++++-------- tests/models/timesfm/test_modeling_timesfm.py | 7 ++- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py index ac638e9efe66..9d1fd8afe7d0 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py @@ -164,9 +164,13 @@ def check_outputs(model_path): freq=frequency_input, ) + # Convert inputs to sequence of tensors + forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32) for ts in forecast_input] + frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long) + # Get predictions from converted model with torch.no_grad(): - outputs = converted_model(inputs=forecast_input, freq=frequency_input, return_dict=True) + outputs = converted_model(inputs=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) point_forecast_conv = outputs.mean_predictions.numpy() quantile_forecast_conv = outputs.full_predictions.numpy() @@ -213,7 +217,9 @@ def main(): required=True, help="Location to write HF model and tokenizer", ) - parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + parser.add_argument( + "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." + ) args = parser.parse_args() write_model( model_path=args.output_dir, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f2134e90e463..7b664ffd0527 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -24,7 +24,7 @@ import logging import math from dataclasses import dataclass -from typing import Any, List, Sequence, Tuple +from typing import List, Sequence, Tuple import numpy as np import torch @@ -452,10 +452,13 @@ def timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Ten def timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: - """Calculates the moving average using NumPy's convolution function.""" + """Calculates the moving average using PyTorch's convolution function.""" # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + # Create a convolution kernel + kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + # Apply convolution to calculate the moving average + smoothed_arr = F.conv1d(arr_padded.unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0)).squeeze() return [smoothed_arr, arr - smoothed_arr] @@ -700,14 +703,16 @@ def __init__(self, config: TimesFMConfig): # Initialize weights and apply final processing self.post_init() - def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: + def _preprocess( + self, inputs: Sequence[torch.Tensor], freq: Sequence[int] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Formats and pads raw inputs to feed into the model. This function both pads each time series to match the context length, and pads the inputs to meet the SPMD shape requirement. Args: - inputs: A list of 1d Tensors. Each JTensor is the context time series of + inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. freq: list of frequencies @@ -722,11 +727,11 @@ def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[ for i, ts in enumerate(inputs): input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + padding = torch.zeros(input_len + self.horizon_len, dtype=torch.float32) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) - padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) + ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32), padding], dim=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] @@ -736,9 +741,9 @@ def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[ inp_freq.append(freq[i]) return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), + torch.stack(input_ts, dim=0), + torch.stack(input_padding, dim=0), + torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1), ) def _postprocess_output( @@ -857,8 +862,8 @@ def decode( def forward( self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, + inputs: Sequence[torch.Tensor], + freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, forecast_context_len: int | None = None, return_forecast_on_context: bool = False, @@ -899,8 +904,13 @@ def forward( fcontext_len = self.context_len else: fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) + + # Get device from first input tensor + device = inputs[0].device + + # Truncate inputs to forecast_context_len + inputs = [ts[-fcontext_len:] for ts in inputs] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: new_inputs = [] @@ -919,13 +929,15 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - input_ts_in = torch.from_numpy(np.array(input_ts, dtype=np.float32)) - input_padding_in = torch.from_numpy(np.array(input_padding, dtype=np.float32)) - inp_freq_in = torch.from_numpy(np.array(inp_freq, dtype=np.int32)).long() + # Move tensors to the same device as input + input_ts = input_ts.to(device) + input_padding = input_padding.to(device) + inp_freq = inp_freq.to(device) + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, + input_ts=input_ts, + paddings=input_padding, + freq=inp_freq, horizon_len=self.horizon_len, return_forecast_on_context=return_forecast_on_context, output_attentions=output_attentions, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index a6cdaf17846e..7fd8f01ff247 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -18,6 +18,7 @@ from typing import List import numpy as np +import torch from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( @@ -115,9 +116,9 @@ def get_pipeline_config(self): def prepare_config_and_inputs(self): forecast_input = [ - np.sin(np.linspace(0, 20, 100)), - np.sin(np.linspace(0, 20, 200)), - np.sin(np.linspace(0, 20, 400)), + torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32), + torch.tensor(np.sin(np.linspace(0, 20, 200)), dtype=torch.float32), + torch.tensor(np.sin(np.linspace(0, 20, 400)), dtype=torch.float32), ] frequency_input = [0, 1, 2] From c523f646fb1d56f8ea8c6a0aabe409472f529b16 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:54:44 +0100 Subject: [PATCH 040/242] rename script --- ...m_orignal_to_pytorch.py => convert_timesfm_orignal_to_hf.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename src/transformers/models/timesfm/{convert_timesfm_orignal_to_pytorch.py => convert_timesfm_orignal_to_hf.py} (99%) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py similarity index 99% rename from src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py rename to src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 9d1fd8afe7d0..eeed750c337b 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -13,7 +13,7 @@ Sample usage: ``` -python src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py \ +python src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py \ --output_dir /output/path ``` """ From e64f562c1abe1cf3f5626aac135db13404231b64 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:04:14 +0100 Subject: [PATCH 041/242] fix docstrings --- .../models/timesfm/modeling_timesfm.py | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 7b664ffd0527..74779f6a1587 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -185,7 +185,7 @@ def forward(self, seq_length=None, position=None): class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM.""" + """Implements the attention used in TimesFM. One key diffrence is that there is _per_dim_scaling of the query.""" def __init__(self, config: TimesFMConfig): super().__init__() @@ -655,7 +655,8 @@ def forward( freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> TimesFMDecoderOutput: + return_dict: bool = True, + ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -674,13 +675,22 @@ def forward( else: all_hidden_states = None - return TimesFMDecoderOutput( - last_hidden_state=transformer_output.last_hidden_state, - hidden_states=all_hidden_states, - attentions=transformer_output.attentions if output_attentions else None, - loc=stats[0], - scale=stats[1], - ) + if return_dict: + return TimesFMDecoderOutput( + last_hidden_state=transformer_output.last_hidden_state, + hidden_states=all_hidden_states, + attentions=transformer_output.attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + else: + return ( + transformer_output.last_hidden_state, + all_hidden_states, + transformer_output.attentions, + stats[0], + stats[1], + ) class TimesFMModelForPrediction(TimesFMPreTrainedModel): @@ -778,7 +788,7 @@ def decode( return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, - ): + ) -> tuple[torch.Tensor, ...]: """Auto-regressive decoding without caching. Args: @@ -799,6 +809,9 @@ def decode( B x H' x (1 + # quantiles). In particular, if return_forecast_on_context is True, H' is H plus the forecastable context length, i.e. context_len - (first) patch_len. + + Raises: + ValueError: If the paddings do not match the input + horizon_len. """ final_out = input_ts context_len = final_out.shape[1] @@ -871,7 +884,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, - ) -> TimesFMOutputForPrediction: + ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]: """Forecasts on a list of time series. Args: @@ -887,15 +900,15 @@ def forward( when available, i.e. after the first input patch. truncate_negative: truncate to only non-negative values if all the contexts have non-negative values. + output_attentions: Whether to return the attentions. + output_hidden_states: Whether to return the hidden states. + return_dict: Whether to return a TimesFMOutputForPrediction object. Returns: - A tuple for Tensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size + A TimesFMOutputForPrediction object containing: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. """ if return_dict is None: return_dict = self.config.use_return_dict From f5dbab95a07959f9252f183aa7b40ca0e781ee84 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:27:16 +0100 Subject: [PATCH 042/242] fix freq when window_size is given --- .../models/timesfm/modeling_timesfm.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 74779f6a1587..b3de01cd3682 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -133,11 +133,11 @@ class TimesFMPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. Attributes: + embedding_dims: Dimension of the embedding to be generated. min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. + the added signal. Defaults to 1. max_timescale: End of the geometric index. Determines the frequency of the - added signal. - embedding_dims: Dimension of the embedding to be generated. + added signal. Defaults to 10_000. """ def __init__( @@ -889,10 +889,9 @@ def forward( Args: inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to Tensor. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. + should be a torch Tensor of potentially different context lengths. + freq: frequency of each context time series in the inputs. 0 for high frequency + (default), 1 for medium, and 2 for low. window_size: window size of trend + residual decomposition. If None then we do not do decomposition. forecast_context_len: optional max context length. @@ -927,9 +926,15 @@ def forward( if window_size is not None: new_inputs = [] - for ts in inputs: + if freq is not None: + new_freqs = [] + for i, ts in enumerate(inputs): new_inputs.extend(timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) inputs = new_inputs + if freq is not None: + freq = new_freqs if freq is None: logging.info("No frequency provided via `freq`. Default to high (0).") From a8dcfa9cf53e57471c06bee52e9706c8d0dc0dfe Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:48:17 +0100 Subject: [PATCH 043/242] add loss --- .../models/timesfm/modeling_timesfm.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b3de01cd3682..b93810265e3d 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -26,7 +26,6 @@ from dataclasses import dataclass from typing import List, Sequence, Tuple -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -38,14 +37,15 @@ @dataclass class TimesFMDecoderOutput(BaseModelOutput): - loc: np.ndarray | None = None - scale: np.ndarray | None = None + loc: torch.Tensor | None = None + scale: torch.Tensor | None = None @dataclass class TimesFMOutputForPrediction(BaseModelOutput): - mean_predictions: np.ndarray | None = None - full_predictions: np.ndarray | None = None + mean_predictions: torch.Tensor | None = None + full_predictions: torch.Tensor | None = None + loss: float | None = None class TimesFMTransformerMLP(nn.Module): @@ -873,11 +873,21 @@ def decode( decoder_output.hidden_states, ) + @staticmethod + def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor, quantiles: List[float]) -> torch.Tensor: + losses = [] + for q in quantiles: + errors = targets - predictions + loss = torch.max((q - 1) * errors, q * errors) + losses.append(loss.mean()) + return torch.stack(losses).mean() + def forward( self, inputs: Sequence[torch.Tensor], freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, + future_target: torch.Tensor | None = None, forecast_context_len: int | None = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, @@ -894,6 +904,7 @@ def forward( (default), 1 for medium, and 2 for low. window_size: window size of trend + residual decomposition. If None then we do not do decomposition. + future_target: optional future target time series to be used for loss computation. forecast_context_len: optional max context length. return_forecast_on_context: True to return the forecast on the context when available, i.e. after the first input patch. @@ -908,6 +919,7 @@ def forward( - the mean forecast of size (# inputs, # forecast horizon), - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). + - loss: the mean squared error loss + quantile loss if future_target is provided. """ if return_dict is None: return_dict = self.config.use_return_dict @@ -927,7 +939,7 @@ def forward( if window_size is not None: new_inputs = [] if freq is not None: - new_freqs = [] + new_freqs = [] for i, ts in enumerate(inputs): new_inputs.extend(timesfm_moving_average(ts, window_size)) if freq is not None: @@ -969,6 +981,12 @@ def forward( mean_outputs = torch.maximum(mean_outputs, 0.0) full_outputs = torch.maximum(full_outputs, 0.0) + loss = None + if future_target is not None: + mse_loss = torch.nn.functional.mse_loss(mean_outputs, future_target) + quantile_loss = self._quantile_loss(full_outputs, future_target, self.config.quantiles) + loss = mse_loss + quantile_loss + if return_dict: return TimesFMOutputForPrediction( last_hidden_state=last_hidden_state, @@ -976,6 +994,7 @@ def forward( hidden_states=all_hidden_states if output_hidden_states else None, mean_predictions=mean_outputs, full_predictions=full_outputs, + loss=loss, ) else: return_tuple = [last_hidden_state] @@ -983,5 +1002,5 @@ def forward( return_tuple.append(all_hidden_states) if output_attentions: return_tuple.append(all_attentions) - return_tuple += [mean_outputs, full_outputs] + return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) From 5fb1fe0a93e6726d6e0f51e4953bc6df4c1ed9af Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 11:00:43 +0100 Subject: [PATCH 044/242] fix _quantile_loss --- .../models/timesfm/modeling_timesfm.py | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b93810265e3d..e20dfa879599 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -45,7 +45,7 @@ class TimesFMDecoderOutput(BaseModelOutput): class TimesFMOutputForPrediction(BaseModelOutput): mean_predictions: torch.Tensor | None = None full_predictions: torch.Tensor | None = None - loss: float | None = None + loss: torch.Tensor | float | None = None class TimesFMTransformerMLP(nn.Module): @@ -74,12 +74,7 @@ def forward(self, x, paddings=None): class TimesFMResidualBlock(nn.Module): """TimesFM residual block.""" - def __init__( - self, - input_dims, - hidden_dims, - output_dims, - ): + def __init__(self, input_dims, hidden_dims, output_dims): super().__init__() self.input_dims = input_dims self.hidden_dims = hidden_dims @@ -572,7 +567,7 @@ def __init__(self, config: TimesFMConfig): self.input_ff_layer = TimesFMResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, - hidden_dims=config.model_dim, + hidden_dims=config.intermediate_size, ) self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) self.stacked_transformer = TimesFMStackedDecoder(config=config) @@ -605,15 +600,8 @@ def _forward_transform( return outputs, (mu, sigma) def _preprocess_input( - self, - input_ts: torch.Tensor, - input_padding: torch.Tensor, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - tuple[torch.Tensor, torch.Tensor] | None, - torch.Tensor, - ]: + self, input_ts: torch.Tensor, input_padding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Preprocess input for stacked transformer.""" # Reshape into patches (using view for efficiency) @@ -707,7 +695,7 @@ def __init__(self, config: TimesFMConfig): self.horizon_ff_layer = TimesFMResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.model_dim, + hidden_dims=config.intermediate_size, ) # Initialize weights and apply final processing @@ -757,9 +745,7 @@ def _preprocess( ) def _postprocess_output( - self, - model_output: torch.Tensor, - stats: tuple[torch.Tensor, torch.Tensor], + self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] ) -> torch.Tensor: """Postprocess output of stacked transformer.""" @@ -873,10 +859,9 @@ def decode( decoder_output.hidden_states, ) - @staticmethod - def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor, quantiles: List[float]) -> torch.Tensor: + def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for q in quantiles: + for q in self.config.quantiles: errors = targets - predictions loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) @@ -972,6 +957,7 @@ def forward( return_forecast_on_context=return_forecast_on_context, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + max_len=fcontext_len, ) if window_size is not None: @@ -983,8 +969,8 @@ def forward( loss = None if future_target is not None: - mse_loss = torch.nn.functional.mse_loss(mean_outputs, future_target) - quantile_loss = self._quantile_loss(full_outputs, future_target, self.config.quantiles) + mse_loss = F.mse_loss(mean_outputs, future_target) + quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_target) loss = mse_loss + quantile_loss if return_dict: From 1445fe594aa6976dbf06d1a9028f7937448d55c0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 11:06:53 +0100 Subject: [PATCH 045/242] formatting --- docs/source/en/model_doc/timesfm.md | 5 +---- src/transformers/models/timesfm/modeling_timesfm.py | 7 ------- tests/models/timesfm/test_modeling_timesfm.py | 11 +---------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 76cf1f8afef2..4e2ee1ae0c61 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -25,11 +25,8 @@ The abstract from the paper is the following: *Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.* -Tips: - - -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +This model was contributed by [kashif](https://huggingface.co/kashif). The original code can be found [here](https://github.com/google-research/timesfm). diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e20dfa879599..ab8894730ca8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -14,13 +14,6 @@ # limitations under the License. """PyTorch TimesFM model.""" - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### - import logging import math from dataclasses import dataclass diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 7fd8f01ff247..da534f4092ea 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -21,25 +21,16 @@ import torch from transformers import TimesFMConfig, is_torch_available -from transformers.testing_utils import ( - require_torch, - torch_device, -) +from transformers.testing_utils import require_torch, torch_device from transformers.utils import is_torch_fx_available -# from ...generation.test_utils import GenerationTesterMixin -# define our own GenerationTesters from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin -# from ...test_pipeline_mixin import PipelineTesterMixin - - if is_torch_fx_available(): pass - if is_torch_available(): from transformers import TimesFMModelForPrediction From 1c1804e421baaff676b7e650d2ac291eda48e57c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 11:17:56 +0100 Subject: [PATCH 046/242] fix isort --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 17a9426554de..37eac1fcd803 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3502,8 +3502,8 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMModelForPrediction", "TimesFMDecoder", + "TimesFMModelForPrediction", "TimesFMPreTrainedModel", ] ) From 35a7e9f8aa4c42ac0bfb700df2c11d4314c3e0cf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 12:13:47 +0100 Subject: [PATCH 047/242] add weight init --- .../models/timesfm/modeling_timesfm.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ab8894730ca8..b64bde1288da 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -546,6 +546,51 @@ def _init_weights(self, module): elif isinstance(module, TimesFMRMSNorm): nn.init.zeros_(module.weight) + elif isinstance(module, TimesFMTransformerMLP): + # Initialize gate projection + module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.gate_proj.bias is not None: + nn.init.zeros_(module.gate_proj.bias) + + # Initialize down projection + module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.down_proj.bias is not None: + nn.init.zeros_(module.down_proj.bias) + + # Initialize layer norm + nn.init.ones_(module.layer_norm.weight) + nn.init.zeros_(module.layer_norm.bias) + + elif isinstance(module, TimesFMAttention): + # Initialize qkv projection + module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.qkv_proj.bias is not None: + nn.init.zeros_(module.qkv_proj.bias) + + # Initialize output projection + module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.o_proj.bias is not None: + nn.init.zeros_(module.o_proj.bias) + + # Initialize scaling parameter + nn.init.ones_(module.scaling) + + elif isinstance(module, TimesFMResidualBlock): + # Initialize hidden layer + module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.hidden_layer[0].bias is not None: + nn.init.zeros_(module.hidden_layer[0].bias) + + # Initialize output layer + module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.output_layer.bias is not None: + nn.init.zeros_(module.output_layer.bias) + + # Initialize residual layer + module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.residual_layer.bias is not None: + nn.init.zeros_(module.residual_layer.bias) + elif isinstance(module, TimesFMPositionalEmbedding): pass From 6c4ddeda03ae0f2665deafb82c3b421bd09b43bd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 8 Dec 2024 13:44:02 +0100 Subject: [PATCH 048/242] add support for sdpa and flash_attention_2 --- .../models/timesfm/configuration_timesfm.py | 4 + .../models/timesfm/modeling_timesfm.py | 194 +++++++++++++++++- 2 files changed, 187 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index eb135c03c968..012315882957 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -62,6 +62,8 @@ class TimesFMConfig(PretrainedConfig): The quantiles to predict. pad_val (`float`, *optional*, defaults to 1123581321.0): The value used to pad the predictions. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention scores. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. initializer_factor (`float`, *optional*, defaults to 1.0): @@ -93,6 +95,7 @@ def __init__( rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, + attention_dropout: float = 0.0, use_positional_embedding: bool = True, initializer_factor: float = 1.0, **kwargs, @@ -110,6 +113,7 @@ def __init__( self.num_heads = num_heads self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps + self.attention_dropout = attention_dropout self.use_positional_embedding = use_positional_embedding self.initializer_factor = initializer_factor diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b64bde1288da..38c8dc9bb6cc 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -25,9 +25,14 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 from .configuration_timesfm import TimesFMConfig +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + @dataclass class TimesFMDecoderOutput(BaseModelOutput): loc: torch.Tensor | None = None @@ -188,9 +193,7 @@ def __init__(self, config: TimesFMConfig): self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = nn.Parameter( - torch.empty((self.head_dim,), dtype=torch.float32), - ) + self.scaling = nn.Parameter(torch.empty((self.head_dim,))) self.qkv_proj = nn.Linear( self.hidden_size, @@ -198,12 +201,8 @@ def __init__(self, config: TimesFMConfig): ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) - def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: - # [batch_size, n_local_heads, input_len, head_dim] - r_softplus_0 = 1.442695041 - softplus_func = torch.nn.Softplus() - scale = r_softplus_0 / math.sqrt(self.head_dim) - scale = scale * softplus_func(self.scaling) + def _scale_query(self, query: torch.Tensor) -> torch.Tensor: + scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim)) return query * scale[None, None, None, :] def forward( @@ -225,7 +224,7 @@ def forward( xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xq = self._per_dim_scaling(xq) + xq = self._scale_query(xq) # Write new kv cache. # [batch_size, input_len, n_local_kv_heads, head_dim] @@ -272,12 +271,182 @@ def forward( return output, scores +class TimesFMFlashAttention2(TimesFMAttention): + """TimesFM attention implementation using Flash Attention 2.""" + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + output_attentions=output_attentions, + ) + + batch_size, seq_length, _ = hidden_states.shape + + # Project to q, k, v + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + + # Scale query using the model's learned scaling + xq = self._scale_query(xq) + + # Handle KV cache + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + key = k_cache + value = v_cache + else: + key = xk + value = xv + + # Handle grouped attention + if self.num_queries_per_kv > 1: + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # Transpose for attention + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Run flash attention + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + seq_length, + dropout_p=self.attention_dropout if self.training else 0.0, + softmax_scale=1, # Set to 1.0 to disable default scaling + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + # Reshape output + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +class TimesFMSdpaAttention(TimesFMAttention): + """TimesFM attention implementation using torch.nn.functional.scaled_dot_product_attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + output_attentions=output_attentions, + ) + + hidden_states_shape = hidden_states.shape + batch_size, seq_length, _ = hidden_states_shape + + # Project to queries, keys, values + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape: [batch_size, seq_length, num_heads * head_dim] -> [batch_size, seq_length, num_heads, head_dim] + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + + # Scale query exactly as in original + xq = self._scale_query(xq) + + # Handle KV cache + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + key = k_cache + value = v_cache + else: + key = xk + value = xv + + # Handle grouped attention + if self.num_queries_per_kv > 1: + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # Transpose for attention: [batch_size, num_heads, seq_length, head_dim] + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Make inputs contiguous + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # Run scaled dot-product attention + # Note: attention_mask should already be in the correct format from TimesFMStackedDecoder + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=False, # We use the provided attention mask + scale=1, # We already scaled the query + ) + + # Reshape output: [batch_size, seq_length, hidden_size] + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +TIMESFM_ATTENTION_CLASSES = { + "eager": TimesFMAttention, + "flash_attention_2": TimesFMFlashAttention2, + "sdpa": TimesFMSdpaAttention, +} + + class TimesFMDecoderLayer(nn.Module): """Transformer layer.""" def __init__(self, config: TimesFMConfig): super().__init__() - self.self_attn = TimesFMAttention(config) + + if config._attn_implementation not in TIMESFM_ATTENTION_CLASSES: + raise ValueError(f"Unknown attention implementation: {config._attn_implementation}") + attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] + + self.self_attn = attention_class(config) self.mlp = TimesFMTransformerMLP(config.model_dim, config.intermediate_size) self.input_layernorm = TimesFMRMSNorm(config.model_dim, eps=config.rms_norm_eps) @@ -529,6 +698,7 @@ class TimesFMPreTrainedModel(PreTrainedModel): config_class = TimesFMConfig base_model_prefix = "timesfm" main_input_name = "inputs" + _supports_sdpa = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -720,6 +890,8 @@ def forward( class TimesFMModelForPrediction(TimesFMPreTrainedModel): + """TimesFM model for quantile and mean prediction.""" + def __init__(self, config: TimesFMConfig): super().__init__(config) From 82c697cbc552703a7700161789a6433238a73b20 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 8 Dec 2024 14:20:51 +0100 Subject: [PATCH 049/242] fixes for flash_attention --- .../models/timesfm/modeling_timesfm.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 38c8dc9bb6cc..e1da8515b6b9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -330,6 +330,14 @@ def forward( key = key.transpose(1, 2) value = value.transpose(1, 2) + # Convert attention mask to proper format for Flash Attention + if attention_mask is not None: + # Convert from [batch_size, 1, seq_length, seq_length] to [batch_size, seq_length] + # by checking which positions are not allowed to attend to any other position + attention_mask = attention_mask.squeeze(1) # [batch_size, seq_length, seq_length] + attention_mask = ~attention_mask.all(dim=-1) # [batch_size, seq_length] + attention_mask = attention_mask.to(query.dtype) + # Run flash attention attn_output = _flash_attention_forward( query, @@ -337,9 +345,10 @@ def forward( value, attention_mask, seq_length, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout=self.attention_dropout if self.training else 0.0, softmax_scale=1, # Set to 1.0 to disable default scaling use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=False, ) # Reshape output @@ -699,6 +708,7 @@ class TimesFMPreTrainedModel(PreTrainedModel): base_model_prefix = "timesfm" main_input_name = "inputs" _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -938,8 +948,8 @@ def _preprocess( padding = torch.zeros(input_len + self.horizon_len, dtype=torch.float32) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32), padding], dim=0) + ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32, device=ts.device), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] From 3bedf98079860d8c48f42ecdcac3288c400abbfc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 8 Dec 2024 14:21:26 +0100 Subject: [PATCH 050/242] formatting --- src/transformers/models/timesfm/modeling_timesfm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e1da8515b6b9..3f60f435b482 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -949,7 +949,9 @@ def _preprocess( if input_len < self.context_len: num_front_pad = self.context_len - input_len ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0) + padding = torch.cat( + [torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0 + ) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] From 2b4f55cb85e034b316469d7ba8bb9a41522f11e3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 10:40:04 +0100 Subject: [PATCH 051/242] remove flash_attention --- .../timesfm/convert_timesfm_orignal_to_hf.py | 35 ++++-- .../models/timesfm/modeling_timesfm.py | 102 +----------------- 2 files changed, 27 insertions(+), 110 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index eeed750c337b..7284f9c24c17 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -26,7 +26,7 @@ def write_model(model_path, safe_serialization=True): tfm = timesfm.TimesFm( hparams=timesfm.TimesFmHparams( - backend="cpu", + backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, ), @@ -139,7 +139,7 @@ def check_outputs(model_path): # Load original model tfm = timesfm.TimesFm( hparams=timesfm.TimesFmHparams( - backend="cpu", + backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, ), @@ -147,7 +147,11 @@ def check_outputs(model_path): ) # Load converted model - converted_model = TimesFMModelForPrediction.from_pretrained(model_path) + converted_model = TimesFMModelForPrediction.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa", + ).to("cuda" if torch.cuda.is_available() else "cpu") converted_model.eval() # Set to evaluation mode # Create test inputs @@ -165,14 +169,19 @@ def check_outputs(model_path): ) # Convert inputs to sequence of tensors - forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32) for ts in forecast_input] - frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long) + forecast_input_tensor = [ + torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu") + for ts in forecast_input + ] + frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to( + "cuda" if torch.cuda.is_available() else "cpu" + ) # Get predictions from converted model with torch.no_grad(): outputs = converted_model(inputs=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) - point_forecast_conv = outputs.mean_predictions.numpy() - quantile_forecast_conv = outputs.full_predictions.numpy() + point_forecast_conv = outputs.mean_predictions.float().cpu().numpy() + quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy() # Compare outputs point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv) @@ -221,11 +230,15 @@ def main(): "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." ) args = parser.parse_args() - write_model( - model_path=args.output_dir, - safe_serialization=args.safe_serialization, - ) + # if the saved model file exists, skip the conversion + if os.path.exists(os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "model.bin")): + print(f"Model already exists in {args.output_dir}, skipping conversion.") + else: + write_model( + model_path=args.output_dir, + safe_serialization=args.safe_serialization, + ) check_outputs(args.output_dir) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 3f60f435b482..f388e456bc8c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -25,14 +25,9 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 from .configuration_timesfm import TimesFMConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - @dataclass class TimesFMDecoderOutput(BaseModelOutput): loc: torch.Tensor | None = None @@ -271,93 +266,6 @@ def forward( return output, scores -class TimesFMFlashAttention2(TimesFMAttention): - """TimesFM attention implementation using Flash Attention 2.""" - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - if output_attentions: - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - output_attentions=output_attentions, - ) - - batch_size, seq_length, _ = hidden_states.shape - - # Project to q, k, v - qkv = self.qkv_proj(hidden_states) - xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Reshape - xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim) - xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) - xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) - - # Scale query using the model's learned scaling - xq = self._scale_query(xq) - - # Handle KV cache - if kv_cache is not None and kv_write_indices is not None: - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - key = k_cache - value = v_cache - else: - key = xk - value = xv - - # Handle grouped attention - if self.num_queries_per_kv > 1: - key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) - value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) - - # Transpose for attention - query = xq.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Convert attention mask to proper format for Flash Attention - if attention_mask is not None: - # Convert from [batch_size, 1, seq_length, seq_length] to [batch_size, seq_length] - # by checking which positions are not allowed to attend to any other position - attention_mask = attention_mask.squeeze(1) # [batch_size, seq_length, seq_length] - attention_mask = ~attention_mask.all(dim=-1) # [batch_size, seq_length] - attention_mask = attention_mask.to(query.dtype) - - # Run flash attention - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - seq_length, - dropout=self.attention_dropout if self.training else 0.0, - softmax_scale=1, # Set to 1.0 to disable default scaling - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=False, - ) - - # Reshape output - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None - - class TimesFMSdpaAttention(TimesFMAttention): """TimesFM attention implementation using torch.nn.functional.scaled_dot_product_attention.""" @@ -440,7 +348,6 @@ def forward( TIMESFM_ATTENTION_CLASSES = { "eager": TimesFMAttention, - "flash_attention_2": TimesFMFlashAttention2, "sdpa": TimesFMSdpaAttention, } @@ -708,7 +615,6 @@ class TimesFMPreTrainedModel(PreTrainedModel): base_model_prefix = "timesfm" main_input_name = "inputs" _supports_sdpa = True - _supports_flash_attn_2 = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -945,13 +851,11 @@ def _preprocess( for i, ts in enumerate(inputs): input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=torch.float32) + padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32, device=ts.device), ts], dim=0) - padding = torch.cat( - [torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0 - ) + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] From 5c4c5917b50698f774ffacf9dde0ae4a4fc2eb24 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 11:03:25 +0100 Subject: [PATCH 052/242] fix tests --- src/transformers/models/timesfm/modeling_timesfm.py | 8 ++++++-- tests/models/timesfm/test_modeling_timesfm.py | 12 +++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f388e456bc8c..35cfd0d537b7 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -177,6 +177,8 @@ class TimesFMAttention(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_heads self.num_kv_heads = config.num_heads @@ -260,7 +262,7 @@ def forward( output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) output = self.o_proj(output) - if output_attentions: + if not output_attentions: scores = None return output, scores @@ -373,6 +375,7 @@ def forward( paddings: torch.Tensor, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + output_attentions: bool = False, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -382,6 +385,7 @@ def forward( attention_mask=attention_mask, kv_write_indices=kv_write_indices, kv_cache=kv_cache, + output_attentions=output_attentions, ) hidden_states = residual + hidden_states @@ -423,6 +427,7 @@ def forward( paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, + output_attentions=output_attentions, ) if output_attentions: all_attentions.append(scores) @@ -727,7 +732,6 @@ def _preprocess_input( self, input_ts: torch.Tensor, input_padding: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Preprocess input for stacked transformer.""" - # Reshape into patches (using view for efficiency) bsize = input_ts.shape[0] patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index da534f4092ea..ea734402c1e5 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -107,15 +107,13 @@ def get_pipeline_config(self): def prepare_config_and_inputs(self): forecast_input = [ - torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32), - torch.tensor(np.sin(np.linspace(0, 20, 200)), dtype=torch.float32), - torch.tensor(np.sin(np.linspace(0, 20, 400)), dtype=torch.float32), + torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device), + torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device), + torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device), ] - frequency_input = [0, 1, 2] + frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device) - config = self.get_config() - - return (config, forecast_input, frequency_input) + return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input) def prepare_config_and_inputs_for_common(self): (config, forecast_input, frequency_input) = self.prepare_config_and_inputs() From f924a317386775552e7b9d9013aa380a52343e72 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 11:06:31 +0100 Subject: [PATCH 053/242] fix file name --- .../models/timesfm/convert_timesfm_orignal_to_hf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 7284f9c24c17..f791fb0236aa 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -3,9 +3,10 @@ import shutil import numpy as np -import timesfm + import torch +import timesfm from transformers import TimesFMConfig, TimesFMModelForPrediction @@ -232,7 +233,9 @@ def main(): args = parser.parse_args() # if the saved model file exists, skip the conversion - if os.path.exists(os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "model.bin")): + if os.path.exists( + os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin") + ): print(f"Model already exists in {args.output_dir}, skipping conversion.") else: write_model( From 84f763e61f60758299e8ca28e13cbf83a3d7e403 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 12:56:22 +0100 Subject: [PATCH 054/242] fix quantile loss --- .../models/timesfm/modeling_timesfm.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 35cfd0d537b7..3bbba4885c5e 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -851,7 +851,7 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - input_ts, input_padding, inp_freq = [], [], [] + input_ts, input_padding = [], [] for i, ts in enumerate(inputs): input_len = ts.shape[0] @@ -866,12 +866,11 @@ def _preprocess( input_ts.append(ts) input_padding.append(padding) - inp_freq.append(freq[i]) return ( torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0), - torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1), + torch.tensor(freq, dtype=torch.int32, device=input_ts[0].device).reshape(-1, 1), ) def _postprocess_output( @@ -991,8 +990,8 @@ def decode( def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for q in self.config.quantiles: - errors = targets - predictions + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[:, :, i] loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) return torch.stack(losses).mean() @@ -1074,11 +1073,6 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - # Move tensors to the same device as input - input_ts = input_ts.to(device) - input_padding = input_padding.to(device) - inp_freq = inp_freq.to(device) - mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( input_ts=input_ts, paddings=input_padding, From ee1e2896a2896c9952212e34be22c902fe1dbfed Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 14:26:52 +0100 Subject: [PATCH 055/242] added initial TimesFMModelIntegrationTests --- .../models/timesfm/modeling_timesfm.py | 14 +++++--- tests/models/timesfm/test_modeling_timesfm.py | 36 +++++++++++++++++-- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 3bbba4885c5e..35cfd0d537b7 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -851,7 +851,7 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - input_ts, input_padding = [], [] + input_ts, input_padding, inp_freq = [], [], [] for i, ts in enumerate(inputs): input_len = ts.shape[0] @@ -866,11 +866,12 @@ def _preprocess( input_ts.append(ts) input_padding.append(padding) + inp_freq.append(freq[i]) return ( torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0), - torch.tensor(freq, dtype=torch.int32, device=input_ts[0].device).reshape(-1, 1), + torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1), ) def _postprocess_output( @@ -990,8 +991,8 @@ def decode( def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for i, q in enumerate(self.config.quantiles): - errors = targets - predictions[:, :, i] + for q in self.config.quantiles: + errors = targets - predictions loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) return torch.stack(losses).mean() @@ -1073,6 +1074,11 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) + # Move tensors to the same device as input + input_ts = input_ts.to(device) + input_padding = input_padding.to(device) + inp_freq = inp_freq.to(device) + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( input_ts=input_ts, paddings=input_padding, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index ea734402c1e5..834748218f2b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -20,8 +20,9 @@ import numpy as np import torch +from huggingface_hub import hf_hub_download from transformers import TimesFMConfig, is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_fx_available from ...test_configuration_common import ConfigTester @@ -32,7 +33,9 @@ pass if is_torch_available(): - from transformers import TimesFMModelForPrediction + from transformers import TimesFMDecoder, TimesFMModelForPrediction + +TOLERANCE = 1e-4 class TimesFMModelTester: @@ -46,7 +49,7 @@ def __init__( num_layers: int = 1, model_dim: int = 16, intermediate_size: int = 32, - head_dim: int = 2, + head_dim: int = 8, num_heads: int = 2, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, @@ -163,3 +166,30 @@ def test_model_main_input_name(self): # The main input is the name of the argument after `self` observed_main_input_name = list(model_signature.parameters.keys())[1] self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name) + + +@require_torch +@slow +class TimesFMModelIntegrationTests(unittest.TestCase): + @classmethod + def load_batch(cls, filename="train-batch.pt"): + file = hf_hub_download( + repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset" + ) + batch = torch.load(file, map_location=torch_device) + return batch + + def test_inference_no_head(self): + model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device) + batch = self.load_batch() + with torch.no_grad(): + inputs = batch["past_values"] + output = model(inputs=inputs).last_hidden_state + self.assertEqual( + output.shape, torch.Size([64, model.config.context_len // model.config.patch_len, model.config.model_dim]) + ) + + expected_slice = torch.tensor( + [[-4.0141, 3.3141, 1.9321], [-4.9121, 3.1443, 2.0836], [-5.1142, 2.7376, 2.1566]], device=torch_device + ) + self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE)) From b9e9633c751ef3f13c7dfc3661ae228322b90d2b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 10 Dec 2024 10:13:40 +0100 Subject: [PATCH 056/242] fix formatting --- tests/models/timesfm/test_modeling_timesfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 834748218f2b..46f758254bbd 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -19,8 +19,8 @@ import numpy as np import torch - from huggingface_hub import hf_hub_download + from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_fx_available @@ -33,7 +33,7 @@ pass if is_torch_available(): - from transformers import TimesFMDecoder, TimesFMModelForPrediction + from transformers import TimesFMModelForPrediction TOLERANCE = 1e-4 From c9dede6252e7a61a2853a15bbe376878724bb236 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 22 Dec 2024 14:47:57 +0100 Subject: [PATCH 057/242] fix import order --- .../models/timesfm/convert_timesfm_orignal_to_hf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index f791fb0236aa..f4256ba653c6 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -3,10 +3,9 @@ import shutil import numpy as np - +import timesfm import torch -import timesfm from transformers import TimesFMConfig, TimesFMModelForPrediction From bc6779761ed1095a86b39d6642fb945fcb78b02c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 22 Dec 2024 15:01:00 +0100 Subject: [PATCH 058/242] fix _quantile_loss --- src/transformers/models/timesfm/modeling_timesfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 35cfd0d537b7..c9ca29bf4b44 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -991,8 +991,8 @@ def decode( def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for q in self.config.quantiles: - errors = targets - predictions + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[..., i] loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) return torch.stack(losses).mean() From 61d5e89945dad2796b7bd77a74115c3f61f0278a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 22 Dec 2024 15:20:47 +0100 Subject: [PATCH 059/242] add doc for SDPA --- docs/source/en/perf_infer_gpu_one.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 930f41b6fefb..ffa4375cefca 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -294,6 +294,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFMModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) From ece0896aa1f74617a9bd439b0c0ee1671fef4e0e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 3 Jan 2025 23:46:44 +0100 Subject: [PATCH 060/242] use timesfm 2.0 --- .../timesfm/convert_timesfm_orignal_to_hf.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index f4256ba653c6..b48e1b8c7dc8 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -19,7 +19,7 @@ """ -def write_model(model_path, safe_serialization=True): +def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) @@ -29,8 +29,13 @@ def write_model(model_path, safe_serialization=True): backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + model_dims=1280, + use_positional_embedding=False, ), - checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) timesfm_config = TimesFMConfig( @@ -42,6 +47,7 @@ def write_model(model_path, safe_serialization=True): intermediate_size=tfm.hparams.model_dims, head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads, num_heads=tfm.hparams.num_heads, + use_positional_embedding=tfm.hparams.use_positional_embedding, ) timesfm_config.save_pretrained(tmp_model_path) timesfm_model = TimesFMModelForPrediction(timesfm_config) @@ -132,7 +138,7 @@ def write_model(model_path, safe_serialization=True): shutil.rmtree(tmp_model_path) -def check_outputs(model_path): +def check_outputs(model_path, huggingface_repo_id): """Compares outputs between original and converted models.""" print("\nChecking model outputs...") @@ -142,8 +148,13 @@ def check_outputs(model_path): backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + model_dims=1280, + use_positional_embedding=False, ), - checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) # Load converted model @@ -229,6 +240,12 @@ def main(): parser.add_argument( "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default="google/timesfm-2.0-500m-pytorch", + help="The Hugging Face repository ID to use for the model.", + ) args = parser.parse_args() # if the saved model file exists, skip the conversion @@ -240,8 +257,9 @@ def main(): write_model( model_path=args.output_dir, safe_serialization=args.safe_serialization, + huggingface_repo_id=args.huggingface_repo_id, ) - check_outputs(args.output_dir) + check_outputs(args.output_dir, args.huggingface_repo_id) if __name__ == "__main__": From 21e3236f3e16d3afd9600539f099b0b821bb624f Mon Sep 17 00:00:00 2001 From: Rajat Sen Date: Fri, 3 Jan 2025 23:33:20 +0000 Subject: [PATCH 061/242] bug fix in timesfm decode function. --- src/transformers/models/timesfm/modeling_timesfm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c9ca29bf4b44..15a1d901797c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -900,7 +900,7 @@ def decode( freq: torch.LongTensor, horizon_len: int, output_patch_len: int | None = None, - max_len: int = 512, + max_len: int | None = None, return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, @@ -914,7 +914,8 @@ def decode( horizon_len: prediction length. output_patch_len: output length to be fetched from one step of auto-regressive decoding. - max_len: maximum training context length. + max_len: maximum training context length. If None, then we use the length + of the initial context as max length. return_forecast_on_context: whether to return the model forecast on the context except the first input patch. @@ -940,6 +941,9 @@ def decode( ) if output_patch_len is None: output_patch_len = self.config.horizon_len + if max_len is None: + max_len = context_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len for step_index in range(num_decode_patches): current_padding = paddings[:, 0 : final_out.shape[1]] @@ -961,7 +965,8 @@ def decode( # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1, new_full_ts.size(3)) + # We have to use reshape and not view for non-contiguous memory + new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3)) full_outputs.append(new_full_ts) From f173a8e1a498e68699fcd5351846254c267e70f8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 6 Jan 2025 14:34:34 +0100 Subject: [PATCH 062/242] compare mean forecasts --- src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index b48e1b8c7dc8..5309caca77a8 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -153,6 +153,7 @@ def check_outputs(model_path, huggingface_repo_id): num_layers=50, model_dims=1280, use_positional_embedding=False, + point_forecast_mode="mean", ), checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) From b83023c413e7aa2f6807d989533c965e5b27724a Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 9 Jan 2025 17:06:59 -0800 Subject: [PATCH 063/242] refactor type hints, use CamelCase --- docs/source/en/model_doc/timesfm.md | 12 +- docs/source/en/perf_infer_gpu_one.md | 2 +- src/transformers/__init__.py | 16 +- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 4 +- src/transformers/models/timesfm/__init__.py | 12 +- .../models/timesfm/configuration_timesfm.py | 6 +- .../timesfm/convert_timesfm_orignal_to_hf.py | 8 +- .../models/timesfm/modeling_timesfm.py | 140 +++++++++--------- src/transformers/utils/dummy_pt_objects.py | 6 +- tests/models/timesfm/test_modeling_timesfm.py | 32 ++-- 11 files changed, 121 insertions(+), 121 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 4e2ee1ae0c61..3edf1aedbb0a 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -30,18 +30,18 @@ This model was contributed by [kashif](https://huggingface.co/kashif). The original code can be found [here](https://github.com/google-research/timesfm). -## TimesFMConfig +## TimesFmConfig -[[autodoc]] TimesFMConfig +[[autodoc]] TimesFmConfig -## TimesFMDecoder +## TimesFmDecoder -[[autodoc]] TimesFMDecoder +[[autodoc]] TimesFmDecoder - forward -## TimesFMModelForPrediction +## TimesFmModelForPrediction -[[autodoc]] TimesFMModelForPrediction +[[autodoc]] TimesFmModelForPrediction - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 5468d65bb48a..435905e9ed62 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -297,7 +297,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) -* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFMModel) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 16be7e9304bb..11de1ee695bc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -791,7 +791,7 @@ ], "models.textnet": ["TextNetConfig"], "models.time_series_transformer": ["TimeSeriesTransformerConfig"], - "models.timesfm": ["TimesFMConfig"], + "models.timesfm": ["TimesFmConfig"], "models.timesformer": ["TimesformerConfig"], "models.timm_backbone": ["TimmBackboneConfig"], "models.timm_wrapper": ["TimmWrapperConfig"], @@ -3604,9 +3604,9 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMDecoder", - "TimesFMModelForPrediction", - "TimesFMPreTrainedModel", + "TimesFmDecoder", + "TimesFmModelForPrediction", + "TimesFmPreTrainedModel", ] ) _import_structure["models.timesformer"].extend( @@ -5835,7 +5835,7 @@ from .models.time_series_transformer import ( TimeSeriesTransformerConfig, ) - from .models.timesfm import TimesFMConfig + from .models.timesfm import TimesFmConfig from .models.timesformer import ( TimesformerConfig, ) @@ -8188,9 +8188,9 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFMDecoder, - TimesFMModelForPrediction, - TimesFMPreTrainedModel, + TimesFmDecoder, + TimesFmModelForPrediction, + TimesFmPreTrainedModel, ) from .models.timesformer import ( TimesformerForVideoClassification, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index fb4d5cc765ba..606d524125ea 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -281,7 +281,7 @@ ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), - ("timesfm", "TimesFMConfig"), + ("timesfm", "TimesFmConfig"), ("timesformer", "TimesformerConfig"), ("timm_backbone", "TimmBackboneConfig"), ("timm_wrapper", "TimmWrapperConfig"), @@ -614,7 +614,7 @@ ("tapex", "TAPEX"), ("textnet", "TextNet"), ("time_series_transformer", "Time Series Transformer"), - ("timesfm", "TimesFM"), + ("timesfm", "TimesFm"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ab60552f447b..b47ddeb2a2fc 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -259,7 +259,7 @@ ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), - ("timesfm", "TimesFMModelForPrediction"), + ("timesfm", "TimesFmModelForPrediction"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), @@ -368,7 +368,7 @@ ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), - ("timesfm", "TimesFMModelForPrediction"), + ("timesfm", "TimesFmModelForPrediction"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 51028a860782..0441d2ce1eda 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -21,7 +21,7 @@ ) -_import_structure = {"configuration_timesfm": ["TimesFMConfig", "TimesFMOnnxConfig"]} +_import_structure = {"configuration_timesfm": ["TimesFmConfig", "TimesFmOnnxConfig"]} try: if not is_torch_available(): @@ -30,13 +30,13 @@ pass else: _import_structure["modeling_timesfm"] = [ - "TimesFMModelForPrediction", - "TimesFMDecoder", - "TimesFMPreTrainedModel", + "TimesFmModelForPrediction", + "TimesFmDecoder", + "TimesFmPreTrainedModel", ] if TYPE_CHECKING: - from .configuration_timesfm import TimesFMConfig, TimesFMOnnxConfig + from .configuration_timesfm import TimesFmConfig, TimesFmOnnxConfig try: if not is_torch_available(): @@ -44,7 +44,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_timesfm import TimesFMDecoder, TimesFMModelForPrediction, TimesFMPreTrainedModel + from .modeling_timesfm import TimesFmDecoder, TimesFmModelForPrediction, TimesFmPreTrainedModel else: import sys diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 012315882957..62e4920ca14a 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -24,9 +24,9 @@ logger = logging.get_logger(__name__) -class TimesFMConfig(PretrainedConfig): +class TimesFmConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`TimesFMModelForPrediction`] or a [`TFTimesFMDecoder`]. It is used to + This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmDecoder`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. @@ -123,7 +123,7 @@ def __init__( ) -class TimesFMOnnxConfig(OnnxSeq2SeqConfigWithPast): +class TimesFmOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = { diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 5309caca77a8..136efdd7d3e9 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -6,7 +6,7 @@ import timesfm import torch -from transformers import TimesFMConfig, TimesFMModelForPrediction +from transformers import TimesFmConfig, TimesFmModelForPrediction """ @@ -38,7 +38,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) - timesfm_config = TimesFMConfig( + timesfm_config = TimesFmConfig( patch_len=tfm.hparams.input_patch_len, context_len=tfm.hparams.context_len, horizon_len=tfm.hparams.horizon_len, @@ -50,7 +50,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google use_positional_embedding=tfm.hparams.use_positional_embedding, ) timesfm_config.save_pretrained(tmp_model_path) - timesfm_model = TimesFMModelForPrediction(timesfm_config) + timesfm_model = TimesFmModelForPrediction(timesfm_config) # copy the weights from the original model to the new model making original_model = tfm._model @@ -159,7 +159,7 @@ def check_outputs(model_path, huggingface_repo_id): ) # Load converted model - converted_model = TimesFMModelForPrediction.from_pretrained( + converted_model = TimesFmModelForPrediction.from_pretrained( model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa", diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 15a1d901797c..c4b3ddac9d7a 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -17,7 +17,7 @@ import logging import math from dataclasses import dataclass -from typing import List, Sequence, Tuple +from typing import List, Sequence, Tuple, Optional, Union import torch import torch.nn as nn @@ -25,23 +25,23 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from .configuration_timesfm import TimesFMConfig +from .configuration_timesfm import TimesFmConfig @dataclass -class TimesFMDecoderOutput(BaseModelOutput): - loc: torch.Tensor | None = None - scale: torch.Tensor | None = None +class TimesFmDecoderOutput(BaseModelOutput): + loc: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None @dataclass -class TimesFMOutputForPrediction(BaseModelOutput): - mean_predictions: torch.Tensor | None = None - full_predictions: torch.Tensor | None = None - loss: torch.Tensor | float | None = None +class TimesFmOutputForPrediction(BaseModelOutput): + mean_predictions: Optional[torch.Tensor] = None + full_predictions: Optional[torch.Tensor] = None + loss: Optional[Union[torch.Tensor, float]] = None -class TimesFMTransformerMLP(nn.Module): +class TimesFmTransformerMLP(nn.Module): """Pax transformer MLP in pytorch.""" def __init__( @@ -64,7 +64,7 @@ def forward(self, x, paddings=None): return outputs + x -class TimesFMResidualBlock(nn.Module): +class TimesFmResidualBlock(nn.Module): """TimesFM residual block.""" def __init__(self, input_dims, hidden_dims, output_dims): @@ -91,7 +91,7 @@ def forward(self, x): return output + residual -class TimesFMRMSNorm(torch.nn.Module): +class TimesFmRMSNorm(torch.nn.Module): """Pax rms norm in pytorch.""" def __init__( @@ -117,7 +117,7 @@ def forward(self, x): return output.type_as(x) -class TimesFMPositionalEmbedding(nn.Module): +class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. Attributes: @@ -172,10 +172,10 @@ def forward(self, seq_length=None, position=None): return signal -class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM. One key diffrence is that there is _per_dim_scaling of the query.""" +class TimesFmAttention(nn.Module): + """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__() self.attention_dropout = config.attention_dropout @@ -205,11 +205,11 @@ def _scale_query(self, query: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: Optional[torch.Tensor] = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -268,17 +268,17 @@ def forward( return output, scores -class TimesFMSdpaAttention(TimesFMAttention): +class TimesFmSdpaAttention(TimesFmAttention): """TimesFM attention implementation using torch.nn.functional.scaled_dot_product_attention.""" def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: Optional[torch.Tensor] = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if output_attentions: return super().forward( hidden_states=hidden_states, @@ -330,7 +330,7 @@ def forward( value = value.contiguous() # Run scaled dot-product attention - # Note: attention_mask should already be in the correct format from TimesFMStackedDecoder + # Note: attention_mask should already be in the correct format from TimesFmStackedDecoder attn_output = F.scaled_dot_product_attention( query, key, @@ -349,15 +349,15 @@ def forward( TIMESFM_ATTENTION_CLASSES = { - "eager": TimesFMAttention, - "sdpa": TimesFMSdpaAttention, + "eager": TimesFmAttention, + "sdpa": TimesFmSdpaAttention, } -class TimesFMDecoderLayer(nn.Module): +class TimesFmDecoderLayer(nn.Module): """Transformer layer.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__() if config._attn_implementation not in TIMESFM_ATTENTION_CLASSES: @@ -365,16 +365,16 @@ def __init__(self, config: TimesFMConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.mlp = TimesFMTransformerMLP(config.model_dim, config.intermediate_size) - self.input_layernorm = TimesFMRMSNorm(config.model_dim, eps=config.rms_norm_eps) + self.mlp = TimesFmTransformerMLP(config.model_dim, config.intermediate_size) + self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, ) -> torch.Tensor: # Self Attention @@ -395,20 +395,20 @@ def forward( return scores, hidden_states -class TimesFMStackedDecoder(nn.Module): +class TimesFmStackedDecoder(nn.Module): """Stacked transformer layer.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__() - self.layers = nn.ModuleList([TimesFMDecoderLayer(config) for _ in range(config.num_layers)]) + self.layers = nn.ModuleList([TimesFmDecoderLayer(config) for _ in range(config.num_layers)]) def forward( self, hidden_states: torch.Tensor, paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_caches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, output_hidden_states: bool = False, ) -> BaseModelOutput: @@ -613,10 +613,10 @@ def expand_t(key_mask): return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum -class TimesFMPreTrainedModel(PreTrainedModel): +class TimesFmPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" - config_class = TimesFMConfig + config_class = TimesFmConfig base_model_prefix = "timesfm" main_input_name = "inputs" _supports_sdpa = True @@ -634,10 +634,10 @@ def _init_weights(self, module): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) - elif isinstance(module, TimesFMRMSNorm): + elif isinstance(module, TimesFmRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, TimesFMTransformerMLP): + elif isinstance(module, TimesFmTransformerMLP): # Initialize gate projection module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.gate_proj.bias is not None: @@ -652,7 +652,7 @@ def _init_weights(self, module): nn.init.ones_(module.layer_norm.weight) nn.init.zeros_(module.layer_norm.bias) - elif isinstance(module, TimesFMAttention): + elif isinstance(module, TimesFmAttention): # Initialize qkv projection module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.qkv_proj.bias is not None: @@ -666,7 +666,7 @@ def _init_weights(self, module): # Initialize scaling parameter nn.init.ones_(module.scaling) - elif isinstance(module, TimesFMResidualBlock): + elif isinstance(module, TimesFmResidualBlock): # Initialize hidden layer module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.hidden_layer[0].bias is not None: @@ -682,26 +682,26 @@ def _init_weights(self, module): if module.residual_layer.bias is not None: nn.init.zeros_(module.residual_layer.bias) - elif isinstance(module, TimesFMPositionalEmbedding): + elif isinstance(module, TimesFmPositionalEmbedding): pass -class TimesFMDecoder(TimesFMPreTrainedModel): +class TimesFmDecoder(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__(config) self.config = config - self.input_ff_layer = TimesFMResidualBlock( + self.input_ff_layer = TimesFmResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, hidden_dims=config.intermediate_size, ) self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) - self.stacked_transformer = TimesFMStackedDecoder(config=config) + self.stacked_transformer = TimesFmStackedDecoder(config=config) if self.config.use_positional_embedding: - self.position_emb = TimesFMPositionalEmbedding( + self.position_emb = TimesFmPositionalEmbedding( embedding_dims=self.config.model_dim, ) @@ -772,7 +772,7 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, - ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]: + ) -> Union[TimesFmDecoderOutput, tuple[torch.Tensor, ...]]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -792,7 +792,7 @@ def forward( all_hidden_states = None if return_dict: - return TimesFMDecoderOutput( + return TimesFmDecoderOutput( last_hidden_state=transformer_output.last_hidden_state, hidden_states=all_hidden_states, attentions=transformer_output.attentions if output_attentions else None, @@ -809,20 +809,20 @@ def forward( ) -class TimesFMModelForPrediction(TimesFMPreTrainedModel): +class TimesFmModelForPrediction(TimesFmPreTrainedModel): """TimesFM model for quantile and mean prediction.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__(config) self.config = config self.context_len = config.context_len self.horizon_len = config.horizon_len - self.decoder = TimesFMDecoder(config) + self.decoder = TimesFmDecoder(config) # quantile and mean output - self.horizon_ff_layer = TimesFMResidualBlock( + self.horizon_ff_layer = TimesFmResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.intermediate_size, @@ -899,8 +899,8 @@ def decode( paddings: torch.Tensor, freq: torch.LongTensor, horizon_len: int, - output_patch_len: int | None = None, - max_len: int | None = None, + output_patch_len: Optional[int] = None, + max_len: Optional[int] = None, return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, @@ -1005,16 +1005,16 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to def forward( self, inputs: Sequence[torch.Tensor], - freq: Sequence[torch.Tensor | int] | None = None, - window_size: int | None = None, - future_target: torch.Tensor | None = None, - forecast_context_len: int | None = None, + freq: Optional[Sequence[Union[torch.Tensor,int]]] = None, + window_size: Optional[int] = None, + future_target: Optional[torch.Tensor] = None, + forecast_context_len: Optional[int] = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]: + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[TimesFmOutputForPrediction, tuple[torch.Tensor, ...]]: """Forecasts on a list of time series. Args: @@ -1032,10 +1032,10 @@ def forward( have non-negative values. output_attentions: Whether to return the attentions. output_hidden_states: Whether to return the hidden states. - return_dict: Whether to return a TimesFMOutputForPrediction object. + return_dict: Whether to return a TimesFmOutputForPrediction object. Returns: - A TimesFMOutputForPrediction object containing: + A TimesFmOutputForPrediction object containing: - the mean forecast of size (# inputs, # forecast horizon), - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). @@ -1109,7 +1109,7 @@ def forward( loss = mse_loss + quantile_loss if return_dict: - return TimesFMOutputForPrediction( + return TimesFmOutputForPrediction( last_hidden_state=last_hidden_state, attentions=all_attentions if output_attentions else None, hidden_states=all_hidden_states if output_hidden_states else None, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ea62dc6188d3..b49d6d0c0157 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9207,21 +9207,21 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMDecoder(metaclass=DummyObject): +class TimesFmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMModelForPrediction(metaclass=DummyObject): +class TimesFmModelForPrediction(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMPreTrainedModel(metaclass=DummyObject): +class TimesFmPreTrainedModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 46f758254bbd..c38b1ab07dd0 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google TimesFM Authors and HuggingFace Inc. team. +# Copyright 2024 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import torch from huggingface_hub import hf_hub_download -from transformers import TimesFMConfig, is_torch_available +from transformers import TimesFmConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_fx_available @@ -33,12 +33,12 @@ pass if is_torch_available(): - from transformers import TimesFMModelForPrediction + from transformers import TimesFmModelForPrediction TOLERANCE = 1e-4 -class TimesFMModelTester: +class TimesFmModelTester: def __init__( self, parent, @@ -84,10 +84,10 @@ def __init__( self.hidden_size = model_dim def get_large_model_config(self): - return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") + return TimesFmConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") def get_config(self): - return TimesFMConfig( + return TimesFmConfig( patch_len=self.patch_len, context_len=self.context_len, horizon_len=self.horizon_len, @@ -129,9 +129,9 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () - all_generative_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () +class TimesFmModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else () + all_generative_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else () all_parallelizable_model_classes = () fx_compatible = False test_pruning = False @@ -141,12 +141,12 @@ class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): test_inputs_embeds = False def setUp(self): - self.model_tester = TimesFMModelTester(self) - self.config_tester = ConfigTester(self, config_class=TimesFMConfig) + self.model_tester = TimesFmModelTester(self) + self.config_tester = ConfigTester(self, config_class=TimesFmConfig) def test_create_and_run_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = TimesFMModelForPrediction(config) + model = TimesFmModelForPrediction(config) model.to(torch_device) model.eval() results = model(**inputs_dict) @@ -162,15 +162,15 @@ def test_headmasking(self): # the main input name is `inputs` def test_model_main_input_name(self): - model_signature = inspect.signature(getattr(TimesFMModelForPrediction, "forward")) + model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward")) # The main input is the name of the argument after `self` observed_main_input_name = list(model_signature.parameters.keys())[1] - self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name) + self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name) @require_torch @slow -class TimesFMModelIntegrationTests(unittest.TestCase): +class TimesFmModelIntegrationTests(unittest.TestCase): @classmethod def load_batch(cls, filename="train-batch.pt"): file = hf_hub_download( @@ -180,7 +180,7 @@ def load_batch(cls, filename="train-batch.pt"): return batch def test_inference_no_head(self): - model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device) + model = TimesFmModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device) batch = self.load_batch() with torch.no_grad(): inputs = batch["past_values"] From b21ec5025dfbf897331f9bcb025dd20002607df8 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Mon, 13 Jan 2025 19:12:57 -0800 Subject: [PATCH 064/242] consolidate decode func --- .../models/timesfm/modeling_timesfm.py | 176 ++++++------------ 1 file changed, 59 insertions(+), 117 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c4b3ddac9d7a..c8e3c28589e6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -893,107 +893,6 @@ def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, t mu, sigma = stats return outputs * sigma[:, None, None, None] + mu[:, None, None, None] - def decode( - self, - input_ts: torch.Tensor, - paddings: torch.Tensor, - freq: torch.LongTensor, - horizon_len: int, - output_patch_len: Optional[int] = None, - max_len: Optional[int] = None, - return_forecast_on_context: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - ) -> tuple[torch.Tensor, ...]: - """Auto-regressive decoding without caching. - - Args: - input_ts: input time-series and paddings. Time-series shape B x C. - paddings: padding shape B x (C + H) where H is the prediction length. - freq: frequency shape B x 1 - horizon_len: prediction length. - output_patch_len: output length to be fetched from one step of - auto-regressive decoding. - max_len: maximum training context length. If None, then we use the length - of the initial context as max length. - return_forecast_on_context: whether to return the model forecast on the - context except the first input patch. - - Returns: - Tuple of two forecasting results: - - Point (mean) output predictions as a tensor with shape B x H'. - - Full predictions (mean and quantiles) as a tensor with shape - B x H' x (1 + # quantiles). - In particular, if return_forecast_on_context is True, H' is H plus - the forecastable context length, i.e. context_len - (first) patch_len. - - Raises: - ValueError: If the paddings do not match the input + horizon_len. - """ - final_out = input_ts - context_len = final_out.shape[1] - full_outputs = [] - - if paddings.shape[1] != final_out.shape[1] + horizon_len: - raise ValueError( - "Length of paddings must match length of input + horizon_len:" - f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" - ) - if output_patch_len is None: - output_patch_len = self.config.horizon_len - if max_len is None: - max_len = context_len - - num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len - for step_index in range(num_decode_patches): - current_padding = paddings[:, 0 : final_out.shape[1]] - input_ts = final_out[:, -max_len:] - input_padding = current_padding[:, -max_len:] - decoder_output = self.decoder( - input_ts, - input_padding, - freq, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - fprop_outputs = self._postprocess_output( - decoder_output.last_hidden_state, - (decoder_output.loc, decoder_output.scale), - ) - - if return_forecast_on_context and step_index == 0: - # For the first decodings step, collect the model forecast on the - # context except the unavailable first input batch forecast. - new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - # We have to use reshape and not view for non-contiguous memory - new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3)) - - full_outputs.append(new_full_ts) - - # (full batch, last patch, output_patch_len, index of mean forecast = 0) - new_ts = fprop_outputs[:, -1, :output_patch_len, 0] - new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] - # (full batch, last patch, output_patch_len, all output indices) - full_outputs.append(new_full_ts) - final_out = torch.concatenate([final_out, new_ts], axis=-1) - - if return_forecast_on_context: - # `full_outputs` indexing starts at after the first input patch. - full_outputs = torch.concatenate(full_outputs, axis=1)[ - :, : (context_len - self.config.patch_len + horizon_len), : - ] - else: - # `full_outputs` indexing starts at the forecast horizon. - full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - - return ( - full_outputs[:, :, 0], - full_outputs, - decoder_output.last_hidden_state, - decoder_output.attentions, - decoder_output.hidden_states, - ) - def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] for i, q in enumerate(self.config.quantiles): @@ -1083,18 +982,61 @@ def forward( input_ts = input_ts.to(device) input_padding = input_padding.to(device) inp_freq = inp_freq.to(device) + + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] - mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( - input_ts=input_ts, - paddings=input_padding, - freq=inp_freq, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - max_len=fcontext_len, - ) + if input_padding.shape[1] != final_out.shape[1] + self.horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}" + ) + output_patch_len = self.config.horizon_len + + num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = input_padding[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -fcontext_len:] + input_padding = current_padding[:, -fcontext_len:] + decoder_output = self.decoder( + input_ts, + input_padding, + inp_freq, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + fprop_outputs = self._postprocess_output( + decoder_output.last_hidden_state, + (decoder_output.loc, decoder_output.scale), + ) + + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] + # We have to use reshape and not view for non-contiguous memory + new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3)) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concatenate([final_out, new_ts], axis=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concatenate(full_outputs, axis=1)[ + :, : (context_len - self.config.patch_len + self.horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:self.horizon_len, :] + mean_outputs = full_outputs[:, :, 0] if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] @@ -1110,18 +1052,18 @@ def forward( if return_dict: return TimesFmOutputForPrediction( - last_hidden_state=last_hidden_state, - attentions=all_attentions if output_attentions else None, - hidden_states=all_hidden_states if output_hidden_states else None, + last_hidden_state=decoder_output.last_hidden_state, + attentions=decoder_output.all_attentions if output_attentions else None, + hidden_states=decoder_output.all_hidden_states if output_hidden_states else None, mean_predictions=mean_outputs, full_predictions=full_outputs, loss=loss, ) else: - return_tuple = [last_hidden_state] + return_tuple = [decoder_output.last_hidden_state] if output_hidden_states: - return_tuple.append(all_hidden_states) + return_tuple.append(decoder_output.all_hidden_states) if output_attentions: - return_tuple.append(all_attentions) + return_tuple.append(decoder_output.all_attentions) return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) From e162102bcedfa1b8993f235a307f4b431ef1c6ed Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 5 Feb 2025 18:38:12 -0800 Subject: [PATCH 065/242] more readable code for weight conversion --- docs/source/en/model_doc/timesfm.md | 3 - src/transformers/models/auto/modeling_auto.py | 2 +- .../timesfm/convert_timesfm_orignal_to_hf.py | 151 +++++++++--------- 3 files changed, 73 insertions(+), 83 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 3edf1aedbb0a..144b29769faf 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -43,6 +43,3 @@ The original code can be found [here](https://github.com/google-research/timesfm [[autodoc]] TimesFmModelForPrediction - forward - - - diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ca36385ab448..603f86badddf 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -264,7 +264,7 @@ ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), - ("timesfm", "TimesFmModelForPrediction"), + ("timesfm", "TimesFmModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 136efdd7d3e9..d3db52aa8dbf 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -1,6 +1,7 @@ import argparse import os import shutil +import re import numpy as np import timesfm @@ -19,6 +20,19 @@ """ +def get_nested_attr(obj, key): + """Recursively retrieves an attribute from an object, handling list/tuple indexing if present.""" + parts = key.split('.') + for part in parts: + match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]` + if match: + attr_name, index = match.groups() + obj = getattr(obj, attr_name)[int(index)] # Access list/tuple element + else: + obj = getattr(obj, part) # Regular attribute access + return obj + + def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") @@ -54,85 +68,64 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google # copy the weights from the original model to the new model making original_model = tfm._model - - # Map decoder input_ff_layer - timesfm_model.decoder.input_ff_layer.hidden_layer[0].weight.data = original_model.input_ff_layer.hidden_layer[ - 0 - ].weight.data - timesfm_model.decoder.input_ff_layer.hidden_layer[0].bias.data = original_model.input_ff_layer.hidden_layer[ - 0 - ].bias.data - timesfm_model.decoder.input_ff_layer.output_layer.weight.data = ( - original_model.input_ff_layer.output_layer.weight.data - ) - timesfm_model.decoder.input_ff_layer.output_layer.bias.data = original_model.input_ff_layer.output_layer.bias.data - timesfm_model.decoder.input_ff_layer.residual_layer.weight.data = ( - original_model.input_ff_layer.residual_layer.weight.data - ) - timesfm_model.decoder.input_ff_layer.residual_layer.bias.data = ( - original_model.input_ff_layer.residual_layer.bias.data - ) - - # Map freq embedding - timesfm_model.decoder.freq_emb.weight.data = original_model.freq_emb.weight.data - - # Map horizon_ff_layer - timesfm_model.horizon_ff_layer.hidden_layer[0].weight.data = original_model.horizon_ff_layer.hidden_layer[ - 0 - ].weight.data - timesfm_model.horizon_ff_layer.hidden_layer[0].bias.data = original_model.horizon_ff_layer.hidden_layer[ - 0 - ].bias.data - timesfm_model.horizon_ff_layer.output_layer.weight.data = original_model.horizon_ff_layer.output_layer.weight.data - timesfm_model.horizon_ff_layer.output_layer.bias.data = original_model.horizon_ff_layer.output_layer.bias.data - timesfm_model.horizon_ff_layer.residual_layer.weight.data = ( - original_model.horizon_ff_layer.residual_layer.weight.data - ) - timesfm_model.horizon_ff_layer.residual_layer.bias.data = original_model.horizon_ff_layer.residual_layer.bias.data - - # Map transformer layers - for i in range(len(timesfm_model.decoder.stacked_transformer.layers)): - # Map attention layers - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.qkv_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.qkv_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.o_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.o_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.scaling.data = original_model.stacked_transformer.layers[i].self_attn.scaling.data - - # Map MLP layers - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.gate_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.gate_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.down_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.down_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.down_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.down_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.layer_norm.weight.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.layer_norm.bias.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.bias.data - - # Map layer norms - timesfm_model.decoder.stacked_transformer.layers[ - i - ].input_layernorm.weight.data = original_model.stacked_transformer.layers[i].input_layernorm.weight.data + + # mapping of the layers from the original model to the transformer model + MODEL_LAYER_MAPPING = { + "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", + "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.hidden_layer[0].bias", + "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight", + "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias", + "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight", + "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias", + + "freq_emb.weight": "decoder.freq_emb.weight", + + "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.hidden_layer[0].weight", + "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.hidden_layer[0].bias", + "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", + "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", + "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", + "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", + } + + + TRANSFORMER_LAYER_MAPPING = { + "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", + "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", + "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.weight", + "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.bias", + "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.stacked_transformer.layers[{i}].self_attn.scaling", + + "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.weight", + "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.bias", + "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.weight", + "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.bias", + "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.weight", + "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.bias", + + "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.stacked_transformer.layers[{i}].input_layernorm.weight", + } + + for old_key, new_key in MODEL_LAYER_MAPPING.items(): + try: + old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model + new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") + + num_layers = len(timesfm_model.decoder.stacked_transformer.layers) + for i in range(num_layers): + for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items(): + old_key = old_template.format(i=i) + new_key = new_template.format(i=i) + + try: + old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model + new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path) From e7531e1fff8e50abb461caac18320de9447f577b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 19:51:35 +0100 Subject: [PATCH 066/242] fix-copies --- docs/source/en/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index eded39d8f9d9..f0323c14bcb9 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -338,7 +338,7 @@ Flax), PyTorch, and/or TensorFlow. | [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ | | [TextNet](model_doc/textnet) | ✅ | ❌ | ❌ | | [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ | -| [TimesFM](model_doc/timesfm) | ✅ | ❌ | ❌ | +| [TimesFm](model_doc/timesfm) | ✅ | ❌ | ❌ | | [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ | | [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ | | [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ | From 0cfb2c35fb82f0a2072d1e1b9f6dbc73b9531924 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:03:57 +0100 Subject: [PATCH 067/242] simpler init --- src/transformers/models/timesfm/__init__.py | 39 ++++--------------- .../models/timesfm/configuration_timesfm.py | 2 +- .../timesfm/convert_timesfm_orignal_to_hf.py | 6 +-- .../models/timesfm/modeling_timesfm.py | 6 +-- 4 files changed, 14 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 0441d2ce1eda..12f1541b9c54 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,42 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import TYPE_CHECKING -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, -) - - -_import_structure = {"configuration_timesfm": ["TimesFmConfig", "TimesFmOnnxConfig"]} +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_timesfm"] = [ - "TimesFmModelForPrediction", - "TimesFmDecoder", - "TimesFmPreTrainedModel", - ] if TYPE_CHECKING: - from .configuration_timesfm import TimesFmConfig, TimesFmOnnxConfig - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_timesfm import TimesFmDecoder, TimesFmModelForPrediction, TimesFmPreTrainedModel - + from .configuration_timesfm import * + from .modeling_timesfm import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 62e4920ca14a..086cfad5c33d 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. +# Copyright 2025 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index d3db52aa8dbf..c3219d9a6097 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -1,7 +1,7 @@ import argparse import os -import shutil import re +import shutil import numpy as np import timesfm @@ -68,7 +68,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google # copy the weights from the original model to the new model making original_model = tfm._model - + # mapping of the layers from the original model to the transformer model MODEL_LAYER_MAPPING = { "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", @@ -88,7 +88,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", } - + TRANSFORMER_LAYER_MAPPING = { "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c8e3c28589e6..dc6a44b5341b 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. +# Copyright 2025 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import logging import math from dataclasses import dataclass -from typing import List, Sequence, Tuple, Optional, Union +from typing import List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -982,7 +982,7 @@ def forward( input_ts = input_ts.to(device) input_padding = input_padding.to(device) inp_freq = inp_freq.to(device) - + final_out = input_ts context_len = final_out.shape[1] full_outputs = [] From 2e29e5fcbf04b244919c8e177c12286a56f5c696 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:07:25 +0100 Subject: [PATCH 068/242] renaem TimesFmMLP --- src/transformers/models/timesfm/modeling_timesfm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index dc6a44b5341b..79a852ab8c01 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -41,8 +41,8 @@ class TimesFmOutputForPrediction(BaseModelOutput): loss: Optional[Union[torch.Tensor, float]] = None -class TimesFmTransformerMLP(nn.Module): - """Pax transformer MLP in pytorch.""" +class TimesFmMLP(nn.Module): + """Pax MLP in pytorch.""" def __init__( self, @@ -365,7 +365,7 @@ def __init__(self, config: TimesFmConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.mlp = TimesFmTransformerMLP(config.model_dim, config.intermediate_size) + self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size) self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( @@ -637,7 +637,7 @@ def _init_weights(self, module): elif isinstance(module, TimesFmRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, TimesFmTransformerMLP): + elif isinstance(module, TimesFmMLP): # Initialize gate projection module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.gate_proj.bias is not None: From 5dc29270e2d1ac9a02f50f785cf9f8b01155caf3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:34:12 +0100 Subject: [PATCH 069/242] use T5LayerNorm --- .../models/timesfm/configuration_timesfm.py | 3 + .../timesfm/convert_timesfm_orignal_to_hf.py | 57 +++++++++---------- .../models/timesfm/modeling_timesfm.py | 46 +++++++-------- 3 files changed, 52 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 086cfad5c33d..81ce310fe098 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -152,3 +152,6 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: @property def default_onnx_opset(self) -> int: return 13 + + +__all__ = ["TimesFmConfig"] diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index c3219d9a6097..f1450fda6910 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -22,7 +22,7 @@ def get_nested_attr(obj, key): """Recursively retrieves an attribute from an object, handling list/tuple indexing if present.""" - parts = key.split('.') + parts = key.split(".") for part in parts: match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]` if match: @@ -71,39 +71,34 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google # mapping of the layers from the original model to the transformer model MODEL_LAYER_MAPPING = { - "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", - "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.hidden_layer[0].bias", - "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight", - "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias", - "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight", - "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias", - - "freq_emb.weight": "decoder.freq_emb.weight", - - "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.hidden_layer[0].weight", - "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.hidden_layer[0].bias", - "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", - "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", - "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", - "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", + "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", + "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.hidden_layer[0].bias", + "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight", + "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias", + "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight", + "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias", + "freq_emb.weight": "decoder.freq_emb.weight", + "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.hidden_layer[0].weight", + "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.hidden_layer[0].bias", + "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", + "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", + "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", + "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", } - TRANSFORMER_LAYER_MAPPING = { - "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", - "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", - "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.weight", - "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.bias", - "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.stacked_transformer.layers[{i}].self_attn.scaling", - - "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.weight", - "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.bias", - "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.weight", - "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.bias", - "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.weight", - "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.bias", - - "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.stacked_transformer.layers[{i}].input_layernorm.weight", + "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", + "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", + "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.weight", + "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.bias", + "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.stacked_transformer.layers[{i}].self_attn.scaling", + "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.weight", + "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.bias", + "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.weight", + "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.bias", + "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.weight", + "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.bias", + "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.stacked_transformer.layers[{i}].input_layernorm.weight", } for old_key, new_key in MODEL_LAYER_MAPPING.items(): diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 79a852ab8c01..49b35cd77890 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -91,30 +91,24 @@ def forward(self, x): return output + residual -class TimesFmRMSNorm(torch.nn.Module): - """Pax rms norm in pytorch.""" - - def __init__( - self, - dim: int, - eps: float = 1e-6, - add_unit_offset: bool = False, - ): +class TimesFmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + TimesFmRMSNorm is equivalent to T5LayerNorm + """ super().__init__() - self.eps = eps - self.add_unit_offset = add_unit_offset - self.weight = nn.Parameter(torch.zeros(dim)) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) - def forward(self, x): - output = self._norm(x.float()) - if self.add_unit_offset: - output = output * (1 + self.weight.float()) - else: - output = output * self.weight.float() - return output.type_as(x) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class TimesFmPositionalEmbedding(nn.Module): @@ -904,7 +898,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to def forward( self, inputs: Sequence[torch.Tensor], - freq: Optional[Sequence[Union[torch.Tensor,int]]] = None, + freq: Optional[Sequence[Union[torch.Tensor, int]]] = None, window_size: Optional[int] = None, future_target: Optional[torch.Tensor] = None, forecast_context_len: Optional[int] = None, @@ -1034,7 +1028,7 @@ def forward( ] else: # `full_outputs` indexing starts at the forecast horizon. - full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:self.horizon_len, :] + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :] mean_outputs = full_outputs[:, :, 0] if window_size is not None: @@ -1067,3 +1061,9 @@ def forward( return_tuple.append(decoder_output.all_attentions) return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) + + +__all__ = [ + "TimesFmModelForPrediction", + "TimesFmPreTrainedModel", +] From 7180f79f929c90c205126dc5be516c5a054388d9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:43:45 +0100 Subject: [PATCH 070/242] fix tests --- docs/source/en/perf_infer_gpu_one.md | 2 +- src/transformers/models/timesfm/modeling_timesfm.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 49a661e6e129..7dfd0012a2c5 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -310,7 +310,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) -* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmDecoder) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 49b35cd77890..b03042faf0d9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -1047,8 +1047,8 @@ def forward( if return_dict: return TimesFmOutputForPrediction( last_hidden_state=decoder_output.last_hidden_state, - attentions=decoder_output.all_attentions if output_attentions else None, - hidden_states=decoder_output.all_hidden_states if output_hidden_states else None, + attentions=decoder_output.attentions if output_attentions else None, + hidden_states=decoder_output.hidden_states if output_hidden_states else None, mean_predictions=mean_outputs, full_predictions=full_outputs, loss=loss, @@ -1056,9 +1056,9 @@ def forward( else: return_tuple = [decoder_output.last_hidden_state] if output_hidden_states: - return_tuple.append(decoder_output.all_hidden_states) + return_tuple.append(decoder_output.hidden_states) if output_attentions: - return_tuple.append(decoder_output.all_attentions) + return_tuple.append(decoder_output.attentions) return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) @@ -1066,4 +1066,5 @@ def forward( __all__ = [ "TimesFmModelForPrediction", "TimesFmPreTrainedModel", + "TimesFmDecoder", ] From cdb4239a2783d3e7f1bf86bedab2267a3c7a28be Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:57:31 +0100 Subject: [PATCH 071/242] use initializer_range --- .../models/timesfm/configuration_timesfm.py | 9 ++++----- .../models/timesfm/modeling_timesfm.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 81ce310fe098..ea110df5a573 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -66,9 +66,8 @@ class TimesFmConfig(PretrainedConfig): The dropout probability for the attention scores. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. - initializer_factor (`float`, *optional*, defaults to 1.0): - A factor for initializing all weight matrices (should be kept to 1, used internally for initialization - testing). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. """ model_type = "timesfm" @@ -97,7 +96,7 @@ def __init__( pad_val: float = 1123581321.0, attention_dropout: float = 0.0, use_positional_embedding: bool = True, - initializer_factor: float = 1.0, + initializer_range: float = 0.02, **kwargs, ): self.patch_len = patch_len @@ -115,7 +114,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.attention_dropout = attention_dropout self.use_positional_embedding = use_positional_embedding - self.initializer_factor = initializer_factor + self.initializer_range = initializer_range super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b03042faf0d9..c95ad08e71f9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -617,10 +617,10 @@ class TimesFmPreTrainedModel(PreTrainedModel): def _init_weights(self, module): if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_range) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.bias is not None: nn.init.zeros_(module.bias) @@ -633,12 +633,12 @@ def _init_weights(self, module): elif isinstance(module, TimesFmMLP): # Initialize gate projection - module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.gate_proj.bias is not None: nn.init.zeros_(module.gate_proj.bias) # Initialize down projection - module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.down_proj.bias is not None: nn.init.zeros_(module.down_proj.bias) @@ -648,12 +648,12 @@ def _init_weights(self, module): elif isinstance(module, TimesFmAttention): # Initialize qkv projection - module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.qkv_proj.bias is not None: nn.init.zeros_(module.qkv_proj.bias) # Initialize output projection - module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.o_proj.bias is not None: nn.init.zeros_(module.o_proj.bias) @@ -662,17 +662,17 @@ def _init_weights(self, module): elif isinstance(module, TimesFmResidualBlock): # Initialize hidden layer - module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range) if module.hidden_layer[0].bias is not None: nn.init.zeros_(module.hidden_layer[0].bias) # Initialize output layer - module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.output_layer.bias is not None: nn.init.zeros_(module.output_layer.bias) # Initialize residual layer - module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.residual_layer.bias is not None: nn.init.zeros_(module.residual_layer.bias) From c48d673db5e972fc9278c8cbd28287ab78b3be2c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:14:30 +0100 Subject: [PATCH 072/242] TimesFmModel instead of TimesFmDecoder --- docs/source/en/model_doc/timesfm.md | 4 ++-- src/transformers/__init__.py | 4 ++-- .../models/timesfm/configuration_timesfm.py | 2 +- src/transformers/models/timesfm/modeling_timesfm.py | 12 ++++++------ src/transformers/utils/dummy_pt_objects.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 144b29769faf..88366594e803 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -34,9 +34,9 @@ The original code can be found [here](https://github.com/google-research/timesfm [[autodoc]] TimesFmConfig -## TimesFmDecoder +## TimesFmModel -[[autodoc]] TimesFmDecoder +[[autodoc]] TimesFmModel - forward ## TimesFmModelForPrediction diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 35db9b81be48..33c8b7ef4ad8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3713,7 +3713,7 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFmDecoder", + "TimesFmModel", "TimesFmModelForPrediction", "TimesFmPreTrainedModel", ] @@ -8401,7 +8401,7 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFmDecoder, + TimesFmModel, TimesFmModelForPrediction, TimesFmPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index ea110df5a573..6f17b0c7bfcc 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -26,7 +26,7 @@ class TimesFmConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmDecoder`]. It is used to + This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmModel`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c95ad08e71f9..6d76a17ea646 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -29,7 +29,7 @@ @dataclass -class TimesFmDecoderOutput(BaseModelOutput): +class TimesFmOutput(BaseModelOutput): loc: Optional[torch.Tensor] = None scale: Optional[torch.Tensor] = None @@ -680,7 +680,7 @@ def _init_weights(self, module): pass -class TimesFmDecoder(TimesFmPreTrainedModel): +class TimesFmModel(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" def __init__(self, config: TimesFmConfig): @@ -766,7 +766,7 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, - ) -> Union[TimesFmDecoderOutput, tuple[torch.Tensor, ...]]: + ) -> Union[TimesFmOutput, tuple[torch.Tensor, ...]]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -786,7 +786,7 @@ def forward( all_hidden_states = None if return_dict: - return TimesFmDecoderOutput( + return TimesFmOutput( last_hidden_state=transformer_output.last_hidden_state, hidden_states=all_hidden_states, attentions=transformer_output.attentions if output_attentions else None, @@ -813,7 +813,7 @@ def __init__(self, config: TimesFmConfig): self.context_len = config.context_len self.horizon_len = config.horizon_len - self.decoder = TimesFmDecoder(config) + self.decoder = TimesFmModel(config) # quantile and mean output self.horizon_ff_layer = TimesFmResidualBlock( @@ -1066,5 +1066,5 @@ def forward( __all__ = [ "TimesFmModelForPrediction", "TimesFmPreTrainedModel", - "TimesFmDecoder", + "TimesFmModel", ] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1c9f99da97a9..c1e591f07903 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9410,7 +9410,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFmDecoder(metaclass=DummyObject): +class TimesFmModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From ce5f2165033d8ce985864a263d4b0f6de23891e9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:24:56 +0100 Subject: [PATCH 073/242] TimesFmPositionalEmbedding takes config for its init --- .../models/timesfm/configuration_timesfm.py | 10 ++++++++ .../models/timesfm/modeling_timesfm.py | 24 ++++--------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 6f17b0c7bfcc..d8eaabdc9bfa 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -68,6 +68,12 @@ class TimesFmConfig(PretrainedConfig): Whether to add positional embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + min_timescale (`int`, *optional*, defaults to 1): + The start of the geometric positional index. Determines the periodicity of + the added signal. + max_timescale (`int`, *optional*, defaults to 10_000): + The end of the geometric positional index. Determines the frequency of the + added signal. """ model_type = "timesfm" @@ -97,6 +103,8 @@ def __init__( attention_dropout: float = 0.0, use_positional_embedding: bool = True, initializer_range: float = 0.02, + min_timescale: int = 1, + max_timescale: int = 10_000, **kwargs, ): self.patch_len = patch_len @@ -115,6 +123,8 @@ def __init__( self.attention_dropout = attention_dropout self.use_positional_embedding = use_positional_embedding self.initializer_range = initializer_range + self.min_timescale = min_timescale + self.max_timescale = max_timescale super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 6d76a17ea646..f4070fa889df 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -113,25 +113,13 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. - - Attributes: - embedding_dims: Dimension of the embedding to be generated. - min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. Defaults to 1. - max_timescale: End of the geometric index. Determines the frequency of the - added signal. Defaults to 10_000. """ - def __init__( - self, - embedding_dims: int, - min_timescale: int = 1, - max_timescale: int = 10_000, - ) -> None: + def __init__(self, config: TimesFmConfig) -> None: super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dims = embedding_dims + self.min_timescale = config.min_timescale + self.max_timescale = config.max_timescale + self.embedding_dims = config.model_dim def forward(self, seq_length=None, position=None): """Generates a Tensor of sinusoids with different frequencies. @@ -695,9 +683,7 @@ def __init__(self, config: TimesFmConfig): self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) self.stacked_transformer = TimesFmStackedDecoder(config=config) if self.config.use_positional_embedding: - self.position_emb = TimesFmPositionalEmbedding( - embedding_dims=self.config.model_dim, - ) + self.position_emb = TimesFmPositionalEmbedding(config=config) # Initialize weights and apply final processing self.post_init() From 9453ed93eb3ae6fff6df8b26d6f1f35d1ac459e1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:35:42 +0100 Subject: [PATCH 074/242] 2.0-500m-pytorch default configs --- .../models/timesfm/configuration_timesfm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index d8eaabdc9bfa..c4577c9e599a 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -29,7 +29,7 @@ class TimesFmConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmModel`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM - [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. + [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -43,7 +43,7 @@ class TimesFmConfig(PretrainedConfig): The length of the prediction horizon. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. - num_layers (`int`, *optional*, defaults to 20): + num_layers (`int`, *optional*, defaults to 50): Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. @@ -64,7 +64,7 @@ class TimesFmConfig(PretrainedConfig): The value used to pad the predictions. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout probability for the attention scores. - use_positional_embedding (`bool`, *optional*, defaults to `True`): + use_positional_embedding (`bool`, *optional*, defaults to `False`): Whether to add positional embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. @@ -91,7 +91,7 @@ def __init__( context_len: int = 512, horizon_len: int = 128, freq_size: int = 3, - num_layers: int = 20, + num_layers: int = 50, model_dim: int = 1280, intermediate_size: int = 1280, head_dim: int = 80, @@ -101,7 +101,7 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, attention_dropout: float = 0.0, - use_positional_embedding: bool = True, + use_positional_embedding: bool = False, initializer_range: float = 0.02, min_timescale: int = 1, max_timescale: int = 10_000, From 61c96fd528d54df011f4a1b27f8537663d7b8322 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:42:57 +0100 Subject: [PATCH 075/242] use TimesFmModel --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 7dfd0012a2c5..49a661e6e129 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -310,7 +310,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) -* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmDecoder) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) From 9538c1d3ad35e38575fd1db4a3e9b022f94748c5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 16 Feb 2025 19:47:09 +0100 Subject: [PATCH 076/242] fix formatting --- src/transformers/models/timesfm/modeling_timesfm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f4070fa889df..c2c8a13f3b6e 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -112,8 +112,7 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): - """Generates position embedding for a given 1-d sequence. - """ + """Generates position embedding for a given 1-d sequence.""" def __init__(self, config: TimesFmConfig) -> None: super().__init__() From bfa69e7bdb876d9dd8713316a3a597af89fd9b0e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 17 Feb 2025 08:59:43 +0100 Subject: [PATCH 077/242] ignore TimesFmModel for testing --- utils/check_repo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index c91edc52cb49..197ef44e3b49 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -143,6 +143,7 @@ "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests "Emu3VQVAE", # Building part of bigger (tested) model "Emu3TextModel", # Building part of bigger (tested) model + "TimesFmModel", # Building part of bigger (tested) model ] ) From c34286fb13f82fec714780daa38d10a3c0da3d36 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 11:57:20 +0100 Subject: [PATCH 078/242] fix docstring --- src/transformers/models/timesfm/configuration_timesfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index c4577c9e599a..570d39c02221 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -71,7 +71,7 @@ class TimesFmConfig(PretrainedConfig): min_timescale (`int`, *optional*, defaults to 1): The start of the geometric positional index. Determines the periodicity of the added signal. - max_timescale (`int`, *optional*, defaults to 10_000): + max_timescale (`int`, *optional*, defaults to 10000): The end of the geometric positional index. Determines the frequency of the added signal. """ From e401b3353f4b394de6f8782c028c15e33191e779 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 12:04:51 +0100 Subject: [PATCH 079/242] override generate as its not needed --- .../models/timesfm/modeling_timesfm.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c2c8a13f3b6e..3e9cc01f8178 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -666,6 +666,32 @@ def _init_weights(self, module): elif isinstance(module, TimesFmPositionalEmbedding): pass + def generate(self, *args, **kwargs): + """ + This method is disabled for TimesFM models. TimesFM models are designed for time series forecasting and should be used + with the forward() method instead. For forecasting, use: + + ```python + # For basic forecasting: + outputs = model(input_ts=your_time_series, input_padding=your_padding, freq=your_frequency) + + # For prediction with quantiles: + outputs = model.forward( + inputs=your_time_series_list, + freq=your_frequencies, + window_size=optional_window_size, + future_target=optional_target, + forecast_context_len=optional_context_length + ) + ``` + + See the model's documentation for more details on the forward method parameters. + """ + raise NotImplementedError( + "The generate() method is not implemented for TimesFM models as they are designed for time series " + "forecasting. Please use the forward() method instead. See the docstring of this method for usage examples." + ) + class TimesFmModel(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" From 85446e3f323b6b9a67a970728c77dfc97a6e74b8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 12:28:06 +0100 Subject: [PATCH 080/242] add doc strings --- .../models/timesfm/modeling_timesfm.py | 100 +++++++++++++----- 1 file changed, 75 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 3e9cc01f8178..bf66c4d4f5f7 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -25,9 +25,15 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_timesfm import TimesFmConfig +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TimesFmConfig" + + @dataclass class TimesFmOutput(BaseModelOutput): loc: Optional[torch.Tensor] = None @@ -594,6 +600,27 @@ def expand_t(key_mask): return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum +TIMESFM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TimesFmConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) class TimesFmPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" @@ -693,6 +720,27 @@ def generate(self, *args, **kwargs): ) +TIMESFM_INPUTS_DOCSTRING = r""" + Args: + inputs: list of time series forecast contexts. Each context time series + should be a torch Tensor of potentially different context lengths. + freq: frequency of each context time series in the inputs. 0 for high frequency + (default), 1 for medium, and 2 for low. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) class TimesFmModel(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" @@ -769,17 +817,23 @@ def _preprocess_input( return model_input, patched_padding, stats, patched_inputs + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) def forward( self, - input_ts: torch.Tensor, + inputs: torch.Tensor, input_padding: torch.LongTensor, freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ) -> Union[TimesFmOutput, tuple[torch.Tensor, ...]]: + """ + input_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The padding indicator of the time series. + """ + model_input, patched_padding, stats, _ = self._preprocess_input( - input_ts=input_ts, + input_ts=inputs, input_padding=input_padding, ) f_emb = self.freq_emb(freq) # B x 1 x D @@ -906,6 +960,8 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to losses.append(loss.mean()) return torch.stack(losses).mean() + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC) def forward( self, inputs: Sequence[torch.Tensor], @@ -919,31 +975,25 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[TimesFmOutputForPrediction, tuple[torch.Tensor, ...]]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be a torch Tensor of potentially different context lengths. - freq: frequency of each context time series in the inputs. 0 for high frequency - (default), 1 for medium, and 2 for low. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - future_target: optional future target time series to be used for loss computation. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - output_attentions: Whether to return the attentions. - output_hidden_states: Whether to return the hidden states. - return_dict: Whether to return a TimesFmOutputForPrediction object. + r""" + window_size (`int`, *optional*): + Window size of trend + residual decomposition. If None thenwe do not do decomposition. + future_target (`torch.Tensor`, *optional*): + Optional future target time series to be used for loss computation. + forecast_context_len (`int`, *optional*): + Optional max context length. + return_forecast_on_context (`bool`, *optional*): + True to return the forecast on the context when available, i.e. after the first input patch. + truncate_negative (`bool`, *optional*): + Truncate to only non-negative values if all the contexts have non-negative values. + have non-ne ative values. Returns: - A TimesFmOutputForPrediction object containing: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - loss: the mean squared error loss + quantile loss if future_target is provided. + A TimesFmOutputForPrediction object or a tuple containing: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + - loss: the mean squared error loss + quantile loss if future_target is provided. """ if return_dict is None: return_dict = self.config.use_return_dict From c410cde64af398e012e56652905f21822f3d7df8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 12:49:28 +0100 Subject: [PATCH 081/242] fix logging --- src/transformers/models/timesfm/modeling_timesfm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index bf66c4d4f5f7..efe417752534 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch TimesFM model.""" -import logging import math from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple, Union @@ -25,7 +24,7 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_timesfm import TimesFmConfig @@ -1023,7 +1022,7 @@ def forward( freq = new_freqs if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") + logger.info("No frequency provided via `freq`. Default to high (0).") freq = [0] * len(inputs) if output_attentions is None: @@ -1124,8 +1123,4 @@ def forward( return tuple(return_tuple) -__all__ = [ - "TimesFmModelForPrediction", - "TimesFmPreTrainedModel", - "TimesFmModel", -] +__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] From 8d5a210f3a619c74fd7ee932ac690a91c50ec8f3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 13:10:37 +0100 Subject: [PATCH 082/242] add docstrings to output data classes --- .../models/timesfm/modeling_timesfm.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index efe417752534..f7211b7a847c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -35,12 +35,30 @@ @dataclass class TimesFmOutput(BaseModelOutput): + """ + Args: + loc (`torch.Tensor` of shape `(batch_size, )`): + The mean of the time series inputs. + scale (`torch.Tensor` of shape `(batch_size,)`): + The scale of the time series inputs. + """ + loc: Optional[torch.Tensor] = None scale: Optional[torch.Tensor] = None @dataclass class TimesFmOutputForPrediction(BaseModelOutput): + """ + Args: + mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The mean predictions of the time series. + full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The full predictions of the time series including the mean and the quantiles. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_target` is provided): + The loss of the TimesFM model. + """ + mean_predictions: Optional[torch.Tensor] = None full_predictions: Optional[torch.Tensor] = None loss: Optional[Union[torch.Tensor, float]] = None From c2625e02e2ff8e9461d8fa7e4b8a38a269a59dd4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Aug 2024 22:42:30 +0200 Subject: [PATCH 083/242] initial copy from t5 --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/timesfm.md | 71 + src/transformers/__init__.py | 24 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 8 + .../models/auto/tokenization_auto.py | 6 + src/transformers/models/timesfm/__init__.py | 66 + .../models/timesfm/configuration_timesfm.py | 163 ++ ...mesfm_original_tf_checkpoint_to_pytorch.py | 59 + .../convert_timesfmx_checkpoint_to_flax.py | 235 ++ .../convert_timesfmx_checkpoint_to_pytorch.py | 238 ++ .../models/timesfm/modeling_timesfm.py | 2388 +++++++++++++++++ tests/models/timesfm/__init__.py | 0 tests/models/timesfm/test_modeling_timesfm.py | 1459 ++++++++++ 15 files changed, 4722 insertions(+) create mode 100644 docs/source/en/model_doc/timesfm.md create mode 100644 src/transformers/models/timesfm/__init__.py create mode 100644 src/transformers/models/timesfm/configuration_timesfm.py create mode 100644 src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py create mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/timesfm/modeling_timesfm.py create mode 100644 tests/models/timesfm/__init__.py create mode 100644 tests/models/timesfm/test_modeling_timesfm.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7d7201da5027..330ac4a83b03 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -615,6 +615,8 @@ title: T5v1.1 - local: model_doc/tapex title: TAPEX + - local: model_doc/timesfm + title: TimesFM - local: model_doc/transfo-xl title: Transformer XL - local: model_doc/ul2 diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md new file mode 100644 index 000000000000..9acc824f9e0f --- /dev/null +++ b/docs/source/en/model_doc/timesfm.md @@ -0,0 +1,71 @@ + + +# TimesFM + +## Overview + +The TimesFM model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## TimesFMConfig + +[[autodoc]] TimesFMConfig + +## TimesFMModel + +[[autodoc]] TimesFMModel + - forward + +## TimesFMForConditionalGeneration + +[[autodoc]] TimesFMForConditionalGeneration + - forward + +## TimesFMEncoderModel + +[[autodoc]] TimesFMEncoderModel + - forward + +## TimesFMForSequenceClassification + +[[autodoc]] TimesFMForSequenceClassification + - forward + +## TimesFMForTokenClassification + +[[autodoc]] TimesFMForTokenClassification + - forward + +## TimesFMForQuestionAnswering + +[[autodoc]] TimesFMForQuestionAnswering + - forward + + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ed2682901008..6a09848414a8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -813,6 +813,7 @@ "models.swinv2": ["Swinv2Config"], "models.switch_transformers": ["SwitchTransformersConfig"], "models.t5": ["T5Config"], + "models.timesfm": ["TimesFMConfig"], "models.table_transformer": ["TableTransformerConfig"], "models.tapas": [ "TapasConfig", @@ -3707,6 +3708,18 @@ "load_tf_weights_in_t5", ] ) + _import_structure["models.timesfm"].extend( + [ + "TimesFMEncoderModel", + "TimesFMForConditionalGeneration", + "TimesFMForQuestionAnswering", + "TimesFMForSequenceClassification", + "TimesFMForTokenClassification", + "TimesFMModel", + "TimesFMPreTrainedModel", + "load_tf_weights_in_timesfm", + ] + ) _import_structure["models.table_transformer"].extend( [ "TableTransformerForObjectDetection", @@ -6000,6 +6013,7 @@ SwitchTransformersConfig, ) from .models.t5 import T5Config + from .models.timesfm import TimesFMConfig from .models.table_transformer import ( TableTransformerConfig, ) @@ -8421,6 +8435,16 @@ T5PreTrainedModel, load_tf_weights_in_t5, ) + from .models.timesfm import ( + TimesFMEncoderModel, + TimesFMForConditionalGeneration, + TimesFMForQuestionAnswering, + TimesFMForSequenceClassification, + TimesFMForTokenClassification, + TimesFMModel, + TimesFMPreTrainedModel, + load_tf_weights_in_timesfm, + ) from .models.table_transformer import ( TableTransformerForObjectDetection, TableTransformerModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 74dad4a2418b..12957e1622bd 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -262,6 +262,7 @@ swinv2, switch_transformers, t5, + timesfm, table_transformer, tapas, textnet, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 8b2b514496d8..e906ed233747 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -291,6 +291,7 @@ ("swinv2", "Swinv2Config"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), + ("timesfm", "TimesFMConfig"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), @@ -639,6 +640,7 @@ ("swinv2", "Swin Transformer V2"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), + ("timesfm", "TimesFM"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index cf6518c41760..0fe7488a4605 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -267,6 +267,7 @@ ("swinv2", "Swinv2Model"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), + ("timesfm", "TimesFMModel"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("textnet", "TextNetModel"), @@ -379,6 +380,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("timesfm", "TimesFMForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -473,6 +475,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("timesfm", "TimesFMForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -960,6 +963,7 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("timesfm", "TimesFMForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] @@ -1065,6 +1069,7 @@ ("stablelm", "StableLmForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), + ("timesfm", "TimesFMForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), @@ -1144,6 +1149,7 @@ ("splinter", "SplinterForQuestionAnswering"), ("squeezebert", "SqueezeBertForQuestionAnswering"), ("t5", "T5ForQuestionAnswering"), + ("timesfm", "TimesFMForQuestionAnswering"), ("umt5", "UMT5ForQuestionAnswering"), ("xlm", "XLMForQuestionAnsweringSimple"), ("xlm-roberta", "XLMRobertaForQuestionAnswering"), @@ -1248,6 +1254,7 @@ ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), + ("timesfm", "TimesFMForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -1476,6 +1483,7 @@ ("roformer", "RoFormerModel"), ("squeezebert", "SqueezeBertModel"), ("t5", "T5EncoderModel"), + ("timesfm", "TimesFMEncoderModel"), ("umt5", "UMT5EncoderModel"), ("xlm", "XLMModel"), ("xlm-roberta", "XLMRobertaModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 61c2c2e23d2f..d77d40720795 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -510,6 +510,12 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "timesfm", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py new file mode 100644 index 000000000000..7398dccbda88 --- /dev/null +++ b/src/transformers/models/timesfm/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = {"configuration_timesfm": ["TimesFMConfig", "TimesFMOnnxConfig"]} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_timesfm"] = [ + "TimesFMEncoderModel", + "TimesFMForConditionalGeneration", + "TimesFMModel", + "TimesFMPreTrainedModel", + "load_tf_weights_in_timesfm", + "TimesFMForQuestionAnswering", + "TimesFMForSequenceClassification", + "TimesFMForTokenClassification", + ] + +if TYPE_CHECKING: + from .configuration_timesfm import TimesFMConfig, TimesFMOnnxConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_timesfm import ( + TimesFMEncoderModel, + TimesFMForConditionalGeneration, + TimesFMForQuestionAnswering, + TimesFMForSequenceClassification, + TimesFMForTokenClassification, + TimesFMModel, + TimesFMPreTrainedModel, + load_tf_weights_in_timesfm, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py new file mode 100644 index 000000000000..065f779b557c --- /dev/null +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2020, The TimesFM Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TimesFM model configuration""" + +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimesFMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TimesFMModel`] or a [`TFTimesFMModel`]. It is used to + instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the TimesFM + [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the TimesFM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`TimesFMModel`] or [`TFTimesFMModel`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `TimesFMBlock`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. TimesFMv1.1 uses the + `"gated-gelu"` feed forward projection. Original TimesFM uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "timesfm" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class TimesFMOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..aa66a8392d4f --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2024 The TimesFM authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert TimesFM checkpoint.""" + +import argparse + +from transformers import TimesFMConfig, TimesFMForConditionalGeneration, load_tf_weights_in_timesfm +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = TimesFMConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = TimesFMForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_timesfm(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained TimesFM model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py new file mode 100644 index 000000000000..98570e22876e --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert TimesFMX checkpoints from the original repository to JAX/FLAX model.""" + +import argparse + +from timesfmx import checkpoints + +from transformers import FlaxTimesFMForConditionalGeneration, TimesFMConfig + + +def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, flax_dump_folder_path): + config = TimesFMConfig.from_pretrained(config_name) + flax_model = FlaxTimesFMForConditionalGeneration(config=config) + timesfmx_model = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) + + split_mlp_wi = "wi_0" in timesfmx_model["target"]["encoder"]["layers_0"]["mlp"] + + # Encoder + for layer_index in range(config.num_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + + # Layer Normalization + timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + + if split_mlp_wi: + timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + + timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( + timesfmx_attention_key + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( + timesfmx_attention_out + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( + timesfmx_attention_query + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( + timesfmx_attention_value + ) + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( + timesfmx_attention_layer_norm + ) + + if split_mlp_wi: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = timesfmx_mlp_wi_0 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = timesfmx_mlp_wi_1 + else: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = ( + timesfmx_mlp_wi + ) + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = ( + timesfmx_mlp_wo + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( + timesfmx_mlp_layer_norm + ) + + # Only for layer 0: + timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = timesfmx_encoder_rel_embedding + + # Assigning + timesfmx_encoder_norm = timesfmx_model["target"]["encoder"]["encoder_norm"]["scale"] + flax_model.params["encoder"]["final_layer_norm"]["weight"] = timesfmx_encoder_norm + + # Decoder + for layer_index in range(config.num_decoder_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + + # Layer Normalization + timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ + "scale" + ] + + # Encoder-Decoder-Attention + timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ + "kernel" + ] + timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ + "kernel" + ] + timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ + "kernel" + ] + timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ + "kernel" + ] + + # Layer Normalization + timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + + # MLP + if split_mlp_wi: + timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + + timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( + timesfmx_attention_key + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( + timesfmx_attention_out + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( + timesfmx_attention_query + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( + timesfmx_attention_value + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( + timesfmx_pre_attention_layer_norm + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = ( + timesfmx_enc_dec_attention_key + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = ( + timesfmx_enc_dec_attention_out + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = ( + timesfmx_enc_dec_attention_query + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = ( + timesfmx_enc_dec_attention_value + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( + timesfmx_cross_layer_norm + ) + + if split_mlp_wi: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = timesfmx_mlp_wi_0 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = timesfmx_mlp_wi_1 + else: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = ( + timesfmx_mlp_wi + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = ( + timesfmx_mlp_wo + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = ( + tx5_mlp_layer_norm + ) + + # Decoder Normalization + tx5_decoder_norm = timesfmx_model["target"]["decoder"]["decoder_norm"]["scale"] + flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm + + # Only for layer 0: + timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = timesfmx_decoder_rel_embedding + + # Token Embeddings + tx5_token_embeddings = timesfmx_model["target"]["token_embedder"]["embedding"] + flax_model.params["shared"]["embedding"] = tx5_token_embeddings + + # LM Head (only in v1.1 checkpoints) + if "logits_dense" in timesfmx_model["target"]["decoder"]: + flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"]["logits_dense"]["kernel"] + + flax_model.save_pretrained(flax_dump_folder_path) + print("TimesFMX Model was sucessfully converted!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + ) + parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of TimesFM model.") + parser.add_argument( + "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + ) + args = parser.parse_args() + convert_timesfmx_checkpoint_to_flax(args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..b761d76bbdcd --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert TimesFMX checkpoint to PyTorch + +Steps: +- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install +- Get a TimesFMX checkpoint at https://github.com/google-research/timesfmx/blob/main/docs/models.md#timesfm-11-checkpoints Example: + `gsutil -m cp -r gs://timesfm-data/pretrained_models/timesfmx/timesfm_1_1_small $HOME/` +- Create or download a corresponding config for the downloaded model. E.g. for TimesFM v1.1 small, you can use + https://huggingface.co/google/timesfm-v1_1-small/blob/main/config.json +- Convert: + ``` + python3 convert_timesfmx_checkpoint_to_pytorch.py --timesfmx_checkpoint_path=$HOME/timesfm_1_1_small --config_file=config.json\ + --pytorch_dump_path=$HOME/timesfm_1_1_small_pt + ``` +""" + +import argparse +import collections + +import torch +from flax import traverse_util +from timesfmx import checkpoints + +from transformers import TimesFMConfig, TimesFMEncoderModel, TimesFMForConditionalGeneration +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def timesfmx_attention_lookup(params, i, prefix, layer_name="attention"): + """Returns the KOQV parameters of (self-)attention. Does not transpose.""" + k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] + o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] + q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] + v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] + return k, o, q, v + + +def timesfmx_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] + return wi, wo + + +def timesfmx_layer_norm_lookup(params, i, prefix, layer_name): + """Returns the layer norm param of a layer.""" + return params[f"{prefix}/layers_{i}/{layer_name}/scale"] + + +def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool): + """Converts the parameters from TimesFMX-Flax to Transformers-PyTorch.""" + old = traverse_util.flatten_dict(variables["target"]) + old = {"/".join(k): v for k, v in old.items()} + + # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi + split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old + print("Split MLP:", split_mlp_wi) + + new = collections.OrderedDict() + + # Shared embeddings. + new["shared.weight"] = old["token_embedder/embedding"] + + # Encoder. + for i in range(num_layers): + # Block i, layer 0 (Self Attention). + layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm") + k, o, q, v = timesfmx_attention_lookup(old, i, "encoder", "attention") + new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm + new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T + new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T + new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T + new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T + + # Block i, layer 1 (MLP). + layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm") + wi, wo = timesfmx_mlp_lookup(old, i, "encoder", split_mlp_wi) + new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm + if split_mlp_wi: + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T + else: + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T + new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T + + new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ + "encoder/relpos_bias/rel_embedding" + ].T + new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] + + if not is_encoder_only: + # Decoder. + for i in range(num_decoder_layers): + # Block i, layer 0 (Self Attention). + layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "self_attention") + new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm + new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T + new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T + new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T + new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T + + # Block i, layer 1 (Cross Attention). + layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") + k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "encoder_decoder_attention") + new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm + new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T + + # Block i, layer 2 (MLP). + layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + wi, wo = timesfmx_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm + if split_mlp_wi: + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T + else: + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T + new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T + + new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] + new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ + "decoder/relpos_bias/rel_embedding" + ].T + + # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) + if "decoder/logits_dense/kernel" in old: + new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T + + return new + + +def make_state_dict(converted_params, is_encoder_only: bool): + """Prepares a state dict for the PyTorch model.""" + # Make a state dict with torch tensors. + state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + + # Add what is missing. + if "encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + if not is_encoder_only: + if "decoder.embed_tokens.weight" not in state_dict: + state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] + + if "lm_head.weight" not in state_dict: # For old 1.0 models. + print("Using shared word embeddings as lm_head.") + state_dict["lm_head.weight"] = state_dict["shared.weight"] + + return state_dict + + +def load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only): + """Replaces the params in model witht the TimesFMX converted params.""" + variables = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) + converted = convert_timesfmx_to_pytorch( + variables, + num_layers=config.num_layers, + num_decoder_layers=config.num_decoder_layers, + is_encoder_only=is_encoder_only, + ) + state_dict = make_state_dict(converted, is_encoder_only) + model.load_state_dict(state_dict, strict=True) + + +def convert_timesfmx_checkpoint_to_pytorch( + timesfmx_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False +): + """Loads the config and model, converts the TimesFMX checkpoint, and saves a PyTorch checkpoint.""" + # Initialise PyTorch model + config = TimesFMConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + # Non-v1.1 checkpoints could also use TimesFMModel, but this works for all. + # The v1.0 checkpoints will simply have an LM head that is the word embeddings. + if is_encoder_only: + model = TimesFMEncoderModel(config) + else: + model = TimesFMForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Verify that we can load the checkpoint. + model.from_pretrained(pytorch_dump_path) + print("Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint.") + # Required parameters + parser.add_argument( + "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path to the TimesFMX checkpoint." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained TimesFM model.\nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False + ) + args = parser.parse_args() + convert_timesfmx_checkpoint_to_pytorch( + args.timesfmx_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only + ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 000000000000..ea6daa33ac6b --- /dev/null +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,2388 @@ +# coding=utf-8 +# Copyright 2024 Mesh TensorFlow authors, TimesFM Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TimesFM model.""" + +import copy +import math +import os +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_timesfm import TimesFMConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TimesFMConfig" +_CHECKPOINT_FOR_DOC = "google/timesfm-1.0-200m" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +# Copied from transformers.models.t5.modeling_t5.load_tf_weights_in_t5 with t5->timesfm +def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, *optional*): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the timesfm models have the + following number of attention modules: + + - google/timesfm-1.0-200m: 6 + - google-timesfm/timesfm-base: 12 + - google-timesfm/timesfm-large: 24 + - google-timesfm/timesfm-3b: 24 + - google-timesfm/timesfm-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using google-timesfm/timesfm-3b, which has a total of 24 attention modules: + model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with google-timesfm/timesfm-3b: + model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->TimesFM +class TimesFMLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the TimesFM style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # TimesFM uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + TimesFMLayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm") +except ImportError: + # using the normal TimesFMLayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to TimesFMLayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->TimesFM +class TimesFMDenseActDense(nn.Module): + def __init__(self, config: TimesFMConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->TimesFM,t5->timesfm +class TimesFMDenseGatedActDense(nn.Module): + def __init__(self, config: TimesFMConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-timesfm-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->TimesFM +class TimesFMLayerFF(nn.Module): + def __init__(self, config: TimesFMConfig): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = TimesFMDenseGatedActDense(config) + else: + self.DenseReluDense = TimesFMDenseActDense(config) + + self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->TimesFM +class TimesFMAttention(nn.Module): + def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->TimesFM +class TimesFMLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = TimesFMAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->TimesFM +class TimesFMLayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = TimesFMAttention(config, has_relative_attention_bias=False) + self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->TimesFM +class TimesFMBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(TimesFMLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(TimesFMLayerCrossAttention(config)) + + self.layer.append(TimesFMLayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->TimesFM +class TimesFMClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: TimesFMConfig): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->TimesFM,t5->timesfm +class TimesFMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TimesFMConfig + load_tf_weights = load_tf_weights_in_timesfm + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["TimesFMBlock"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, TimesFMLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (TimesFMModel, TimesFMForConditionalGeneration, TimesFMEncoderModel, TimesFMForQuestionAnswering), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, TimesFMForTokenClassification): + if hasattr(module, "classifier"): + module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.data.zero_() + elif isinstance(module, TimesFMClassificationHead): + module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, TimesFMDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, TimesFMDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, TimesFMAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In TimesFM it is usually set to the pad_token_id. " + "See TimesFM docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->TimesFM +class TimesFMStack(TimesFMPreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMStack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +TIMESFM_START_DOCSTRING = r""" + + The TIMESFM model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TimesFMConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TIMESFM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + TIMESFM uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [TIMESFM + Training](./timesfm#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +TIMESFM_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare TIMESFM Model transformer outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFMModel(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TimesFMStack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" + " 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, TimesFMModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") + >>> model = TimesFMModel.from_pretrained("google/timesfm-1.0-200m") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for TimesFMModel. + >>> # This is not needed for torch's TimesFMForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING) +class TimesFMForConditionalGeneration(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TimesFMStack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" + " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" + " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TimesFMForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") + >>> model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + decoder_attention_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare TIMESFM Model transformer outputting encoder's raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFMEncoderModel(TimesFMPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight"] + _keys_to_ignore_on_load_unexpected = [r"decoder"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`TimesFMEncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, TimesFMEncoderModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") + >>> model = TimesFMEncoderModel.from_pretrained("google/timesfm-1.0-200m") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs + + +@add_start_docstrings( + """ + TIMESFM model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + TIMESFM_START_DOCSTRING, +) +class TimesFMForSequenceClassification(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.transformer = TimesFMModel(config) + self.classification_head = TimesFMClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + # decoder_input_ids from input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + batch_size, _, hidden_size = sequence_output.shape + sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + TIMESFM Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output) + e.g. for Named-Entity-Recognition (NER) tasks. + """, + TIMESFM_START_DOCSTRING, +) +class TimesFMForTokenClassification(TimesFMPreTrainedModel): + _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = TimesFMEncoderModel(config) + self.dropout = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, outputs[2:-1]) + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + TIMESFM Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers + on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + TIMESFM_START_DOCSTRING, +) +class TimesFMForQuestionAnswering(TimesFMPreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = TimesFMStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TimesFMStack(decoder_config, self.shared) + + self.num_labels = config.num_labels + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if start_positions is not None and end_positions is not None: + use_cache = False + + # different to other models, TIMESFM automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=None, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/tests/models/timesfm/__init__.py b/tests/models/timesfm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py new file mode 100644 index 000000000000..e5878f8c51c7 --- /dev/null +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -0,0 +1,1459 @@ +# coding=utf-8 +# Copyright 2024 Google TimesFM Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import os +import pickle +import tempfile +import unittest + +from transformers import TimesFMConfig, is_torch_available +from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +from transformers.testing_utils import ( + require_accelerate, + require_sentencepiece, + require_tokenizers, + require_torch, + slow, + torch_device, +) +from transformers.utils import cached_property, is_torch_fx_available + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + + +if is_torch_available(): + import torch + + from transformers import ( + AutoTokenizer, + ByT5Tokenizer, + TimesFMEncoderModel, + TimesFMForConditionalGeneration, + TimesFMForQuestionAnswering, + TimesFMForSequenceClassification, + TimesFMForTokenClassification, + TimesFMModel, + T5Tokenizer, + ) + + +class TimesFMModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=7, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + decoder_layers=None, + ): + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = decoder_layers + + def get_large_model_config(self): + return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2) + input_ids[:, -1] = self.eos_token_id # Eos Token + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = self.get_config() + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def get_pipeline_config(self): + return TimesFMConfig( + vocab_size=166, # timesfm forces 100 extra tokens + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def get_config(self): + return TimesFMConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def check_prepare_lm_labels_via_shift_left( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config) + model.to(torch_device) + model.eval() + + # make sure that lm_labels are correctly padded from the right + lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) + + # add casaul pad token mask + triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() + lm_labels.masked_fill_(triangular_mask, self.pad_token_id) + decoder_input_ids = model._shift_right(lm_labels) + + for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): + # first item + self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) + if i < decoder_input_ids_slice.shape[-1]: + if i < decoder_input_ids.shape[-1] - 1: + # items before diagonal + self.parent.assertListEqual( + decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() + ) + # pad items after diagonal + if i < decoder_input_ids.shape[-1] - 2: + self.parent.assertListEqual( + decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() + ) + else: + # all items after square + self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_with_sequence_classification_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device) + model = TimesFMForSequenceClassification(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=input_ids, + labels=labels, + ) + # self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).get_decoder() + model.to(torch_device) + model.eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = TimesFMModel(config=config).to(torch_device).half().eval() + output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_encoder_decoder_shared_weights( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + for model_class in [TimesFMModel, TimesFMForConditionalGeneration]: + torch.manual_seed(0) + model = model_class(config=config).to(torch_device).eval() + # load state dict copies weights but does not tie them + model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) + + torch.manual_seed(0) + tied_config = copy.deepcopy(config) + tied_config.tie_encoder_decoder = True + tied_model = model_class(config=tied_config).to(torch_device).eval() + + model_result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + tied_model.save_pretrained(tmpdirname) + tied_model = model_class.from_pretrained(tmpdirname) + tied_model.to(torch_device) + tied_model.eval() + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], + tied_model_result[0][0, :, random_slice_idx], + atol=1e-4, + ) + ) + + def check_resize_embeddings_timesfm_v1_1( + self, + config, + ): + prev_vocab_size = config.vocab_size + + config.tie_word_embeddings = False + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model.resize_token_embeddings(prev_vocab_size - 10) + + self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "use_cache": False, + } + return config, inputs_dict + + +@require_torch +class TimesFMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForSequenceClassification, TimesFMForQuestionAnswering) + if is_torch_available() + else () + ) + all_generative_model_classes = (TimesFMForConditionalGeneration,) if is_torch_available() else () + all_parallelizable_model_classes = (TimesFMModel, TimesFMForConditionalGeneration) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = True + test_model_parallel = True + is_encoder_decoder = True + # The small TimesFM model needs higher percentages for CPU/MP tests + model_split_percents = [0.5, 0.8, 0.9] + + def setUp(self): + self.model_tester = TimesFMModelTester(self) + self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) + + # TimesFMForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForQuestionAnswering): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_config_and_model_silu_gated(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.feed_forward_proj = "gated-silu" + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_with_sequence_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + def test_encoder_decoder_shared_weights(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + def test_v1_1_resize_embeddings(self): + config = self.model_tester.prepare_config_and_inputs()[0] + self.model_tester.check_resize_embeddings_timesfm_v1_1(config) + + @slow + def test_model_from_pretrained(self): + model_name = "google/timesfm-1.0-200m" + model = TimesFMModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip(reason="Test has a segmentation fault on torch 1.8.0") + def test_export_to_onnx(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + model = TimesFMModel(config_and_inputs[0]).to(torch_device) + with tempfile.TemporaryDirectory() as tmpdirname: + torch.onnx.export( + model, + (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), + f"{tmpdirname}/timesfm_test.onnx", + export_params=True, + opset_version=9, + input_names=["input_ids", "decoder_input_ids"], + ) + + def test_generate_with_head_masking(self): + attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + max_length = config_and_inputs[1].shape[-1] + 3 + model = TimesFMForConditionalGeneration(config).eval() + model.to(torch_device) + + head_masking = { + "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device), + "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + } + + for attn_name, (name, mask) in zip(attention_names, head_masking.items()): + head_masks = {name: mask} + # Explicitly pass decoder_head_mask as it is required from TimesFM model when head_mask specified + if name == "head_mask": + head_masks["decoder_head_mask"] = torch.ones( + config.num_decoder_layers, config.num_heads, device=torch_device + ) + + out = model.generate( + config_and_inputs[1], + num_beams=1, + max_length=max_length, + output_attentions=True, + return_dict_in_generate=True, + **head_masks, + ) + # We check the state of decoder_attentions and cross_attentions just from the last step + attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] + self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + + +class TimesFMEncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + # For common tests + use_attention_mask=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + is_training=False, + dropout_rate=0.1, + initializer_factor=0.002, + is_encoder_decoder=False, + eos_token_id=1, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + # For common tests + self.seq_length = self.encoder_seq_length + self.use_attention_mask = use_attention_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.is_encoder_decoder = is_encoder_decoder + self.scope = None + self.is_training = is_training + + def get_large_model_config(self): + return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = TimesFMConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = TimesFMEncoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + attention_mask, + ): + model = TimesFMEncoderModel(config=config).to(torch_device).half().eval() + output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_with_token_classification_head( + self, + config, + input_ids, + attention_mask, + ): + labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) + model = TimesFMForTokenClassification(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + ) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class TimesFMEncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (TimesFMEncoderModel, TimesFMForTokenClassification) if is_torch_available() else () + test_pruning = False + test_resize_embeddings = False + test_model_parallel = True + pipeline_model_mapping = ( + { + "token-classification": TimesFMForTokenClassification, + } + if is_torch_available() + else {} + ) + all_parallelizable_model_classes = (TimesFMEncoderModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = TimesFMEncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + def test_with_token_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) + + +def use_task_specific_params(model, task): + model.config.update(model.config.task_specific_params[task]) + + +@require_torch +@require_accelerate +@require_tokenizers +@slow +class TimesFMModelFp16Tests(unittest.TestCase): + def test_fp16_fp32_conversion(self): + r""" + A test to check whether the argument `keep_in_fp32_modules` correctly does its job + """ + orig_import = __import__ + accelerate_mock = unittest.mock.Mock() + + # mock import of accelerate + def import_accelerate_mock(name, *args, **kwargs): + if name == "accelerate": + if accelerate_available: + return accelerate_mock + else: + raise ImportError + return orig_import(name, *args, **kwargs) + + # Load without using `accelerate` + with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): + accelerate_available = False + + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load without in bf16 + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, device_map="auto" + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + + # Load without using `accelerate` + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load using `accelerate` + model = TimesFMForConditionalGeneration.from_pretrained( + "google/timesfm-1.0-200m", torch_dtype=torch.float16, device_map="auto" + ) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class TimesFMModelIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-base").to(torch_device) + + @cached_property + def tokenizer(self): + return T5Tokenizer.from_pretrained("google-timesfm/timesfm-base") + + @slow + def test_torch_quant(self): + r""" + Test that a simple `torch.quantization.quantize_dynamic` call works on a TimesFM model. + """ + model_name = "google/flan-timesfm-small" + tokenizer = T5Tokenizer.from_pretrained(model_name) + model = TimesFMForConditionalGeneration.from_pretrained(model_name) + model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) + input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" + input_ids = tokenizer(input_text, return_tensors="pt").input_ids + _ = model.generate(input_ids) + + @slow + def test_small_generation(self): + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + model.config.max_length = 8 + model.config.num_beams = 1 + model.config.do_sample = False + tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") + + input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device) + + sequences = model.generate(input_ids) + + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] + self.assertTrue(output_str == "Hello there!") + + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import timesfm # pip install timesfm==0.7.1 + >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_timesfm_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -19.0845 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_v1_1_integration_test(self): + """ + For comparision run: + >>> import timesfm # pip install timesfm==0.7.1 + >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_timesfm_v1_1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_v1_1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-v1_1-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google/timesfm-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -59.0293 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_bytimesfm_integration_test(self): + """ + For comparision run: + >>> import timesfm # pip install timesfm==0.9.1 + + >>> path_to_bytimesfm_small_checkpoint = '' + >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) + >>> vocab = timesfm.data.ByteVocabulary() + >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TimesFMForConditionalGeneration.from_pretrained("google/bytimesfm-small").to(torch_device) + tokenizer = ByT5Tokenizer.from_pretrained("google/bytimesfm-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -60.7397 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_summarization(self): + model = self.model + tok = self.tokenizer + + FRANCE_ARTICLE = ( # @noqa + "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" + " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." + ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' + ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' + " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" + " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" + " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" + " phone at the wreckage site. The two publications described the supposed video, but did not post it on" + " their websites. The publications said that they watched the video, which was found by a source close to" + " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." + ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' + " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" + ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' + " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" + " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" + " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" + ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' + ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' + " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" + " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" + " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" + ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' + ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' + ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' + ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' + " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" + ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' + " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" + " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" + ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' + ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' + " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" + " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" + " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" + " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" + ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' + " sharing the information and documents -- including training and medical records -- with public" + " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" + " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" + " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" + " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" + " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." + " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" + " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." + " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." + " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" + " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" + " the flight school during his training were among several developments as investigators continued to" + " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" + " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" + ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' + " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" + " some point before his aviation career and underwent psychotherapy before he got his pilot's license." + " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" + " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" + " lose his pilot's license, a European government official briefed on the investigation told CNN on" + ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' + " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" + " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" + " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" + " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" + " he had psychological issues, the European government official said. But no matter what details emerge" + " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" + ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' + " that maybe they weren't going to keep doing their job and they're upset about that and so they're" + ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' + " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" + ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' + " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" + " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" + " Amiel and Anna-Maja Rappard contributed to this report." + ) + SHORTER_ARTICLE = ( + "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" + " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" + " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." + " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" + ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' + ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' + " situation in Palestinian territories, paving the way for possible war crimes investigations against" + " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" + " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" + " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" + ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' + ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' + ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' + " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" + ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' + " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." + ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' + ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' + " immediately end their pressure, and countries that support universal acceptance of the court's treaty" + ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' + " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" + ' decision to join a treaty to which over 100 countries around the world are members." In January, when' + " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" + ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' + " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" + ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' + ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' + ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' + " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" + ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' + " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" + ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' + " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" + " will include alleged war crimes committed since June. The International Criminal Court was set up in" + " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" + " and Faith Karimi contributed to this report." + ) + IRAN_ARTICLE = ( + "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" + " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" + " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." + " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" + " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" + " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" + " the announcement of the new framework will likely result in more heat than light. It will not be helped" + " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." + " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" + " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" + " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" + " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" + " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" + " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" + " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" + " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" + " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" + " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" + " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" + " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" + " point, and we'll know even more about Iran's program in the coming months and years because of the deal." + " In fact, the inspections provisions that are part of this agreement are designed to protect against any" + " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" + " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" + " warning that a deal might be killed by Congress or a future president). This of course is not the case." + " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," + " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" + " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" + " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" + " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" + " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" + " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" + " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" + " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" + " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" + " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" + " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" + ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' + " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" + " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" + " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" + " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" + " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" + " some insist that any agreement must address Iranian missile programs, human rights violations or support" + " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" + " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" + " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" + " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" + " fact-based, not based on questionable assertions or dubious assumptions." + ) + ARTICLE_SUBWAY = ( + "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + + expected_summaries = [ + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' + " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" + " magazine says .", + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" + " preliminary examination into the situation in the occupied Palestinian territory . as members of the" + " court, Palestinians may be subject to counter-charges as well .", + "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" + " the debate that has already begun since the announcement of the new framework will likely result in more" + " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" + " implement a rigorous inspection regime .", + "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" + ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' + " times, with nine of her marriages occurring between 1999 and 2002 .", + ] + + use_task_specific_params(model, "summarization") + + dct = tok( + [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(torch_device) + self.assertEqual(512, dct["input_ids"].shape[1]) + + hypotheses_batch = model.generate( + **dct, + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertListEqual( + expected_summaries, + decoded, + ) + + @slow + def test_translation_en_to_de(self): + model = self.model + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_de") + + en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' + expected_translation = ( + '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' + ) + + input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") + input_ids = input_ids.to(torch_device) + output = model.generate(input_ids) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(translation, expected_translation) + + @slow + def test_translation_en_to_fr(self): + model = self.model # google-timesfm/timesfm-base + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_fr") + + en_text = ( + ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' + " countless generations of stars: the oldest stars are seen as blue dots. " + ) + + input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") + input_ids = input_ids.to(torch_device) + + output = model.generate( + input_ids=input_ids, + num_beams=4, + length_penalty=2.0, + max_length=100, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + new_truncated_translation = ( + "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " + "un " + "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " + "sous forme " + "de points bleus." + ) + + self.assertEqual(translation, new_truncated_translation) + + @slow + def test_translation_en_to_ro(self): + model = self.model + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_ro") + en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022." + expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." + + inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device) + output = model.generate(**inputs) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(translation, expected_translation) + + @slow + def test_contrastive_search_timesfm(self): + article = ( + " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + article = "summarize: " + article.strip() + timesfm_tokenizer = AutoTokenizer.from_pretrained("flax-community/timesfm-base-cnn-dm") + timesfm_model = TimesFMForConditionalGeneration.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) + input_ids = timesfm_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" + ).input_ids.to(torch_device) + + outputs = timesfm_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = timesfm_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " + "permanent residence after the marriages, prosecutors say." + ], + ) + + +@require_torch +class TestAsymmetricTimesFM(unittest.TestCase): + def build_model_and_check_forward_pass(self, **kwargs): + tester = TimesFMModelTester(self, **kwargs) + config, *inputs = tester.prepare_config_and_inputs() + ( + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = inputs + model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + # outputs = model(*inputs) + assert len(outputs) == 4 + assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size) + assert outputs["loss"].size() == () + return model + + def test_small_decoder(self): + # num_hidden_layers is passed to TimesFMConfig as num_layers + model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2) + assert len(model.encoder.block) == 2 + assert len(model.decoder.block) == 1 + + def test_defaulting_to_symmetry(self): + # num_hidden_layers is passed to TimesFMConfig as num_layers + model = self.build_model_and_check_forward_pass(num_hidden_layers=2) + assert len(model.decoder.block) == len(model.encoder.block) == 2 From f43a0df6a62c6ea099c0cf2ce5e5ac3bdb3ff7fd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Aug 2024 23:25:52 +0200 Subject: [PATCH 084/242] added config and attention layers --- .../models/timesfm/configuration_timesfm.py | 15 ++--- .../models/timesfm/modeling_timesfm.py | 67 +++++++------------ 2 files changed, 29 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 065f779b557c..ad66b752b3b3 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -35,9 +35,6 @@ class TimesFMConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Arguments: - vocab_size (`int`, *optional*, defaults to 32128): - Vocabulary size of the TimesFM model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`TimesFMModel`] or [`TFTimesFMModel`]. d_model (`int`, *optional*, defaults to 512): Size of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): @@ -77,13 +74,12 @@ class TimesFMConfig(PretrainedConfig): def __init__( self, - vocab_size=32128, - d_model=512, - d_kv=64, - d_ff=2048, - num_layers=6, + d_model=1280, + d_kv=80, + d_ff=1280, + num_layers=20, num_decoder_layers=None, - num_heads=8, + num_heads=16, relative_attention_num_buckets=32, relative_attention_max_distance=128, dropout_rate=0.1, @@ -97,7 +93,6 @@ def __init__( classifier_dropout=0.0, **kwargs, ): - self.vocab_size = vocab_size self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ea6daa33ac6b..ee4701cdaded 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -21,6 +21,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -270,12 +271,11 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) -# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->TimesFM class TimesFMDenseActDense(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.wi = nn.Linear(config.d_model, config.d_ff, bias=True) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=True) self.dropout = nn.Dropout(config.dropout_rate) self.act = ACT2FN[config.dense_act_fn] @@ -293,56 +293,35 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->TimesFM,t5->timesfm -class TimesFMDenseGatedActDense(nn.Module): +class TimesFMLayerFF(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + + self.DenseReluDense = TimesFMDenseActDense(config) + self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - self.act = ACT2FN[config.dense_act_fn] def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - # To make 8bit quantization work for google/flan-timesfm-xxl, self.wo is kept in float32. - # See https://github.com/huggingface/transformers/issues/20287 - # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): - hidden_states = hidden_states.to(self.wo.weight.dtype) - - hidden_states = self.wo(hidden_states) + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states -# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->TimesFM -class TimesFMLayerFF(nn.Module): +class TimesFMPerHeadDimScale(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - if config.is_gated_act: - self.DenseReluDense = TimesFMDenseGatedActDense(config) - else: - self.DenseReluDense = TimesFMDenseActDense(config) - self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) + self.dim = config.d_model // config.num_heads + self.scale = nn.Parameter(torch.zeros(self.dim)) def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states + r_softplus_0 = 1.442695041 + scale = r_softplus_0 / math.sqrt(self.dim) + scale *= F.softplus(self.scale) + return hidden_states * scale -# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->TimesFM class TimesFMAttention(nn.Module): def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): super().__init__() @@ -357,10 +336,11 @@ def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): self.inner_dim = self.n_heads * self.key_value_proj_dim # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + self.q = nn.Linear(self.d_model, self.inner_dim, bias=True) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=True) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=True) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=True) + self.per_head_dim_scale = TimesFMPerHeadDimScale(config) if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) @@ -515,7 +495,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return hidden_states # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + unscaled_query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.per_head_dim_scale(unscaled_query_states) # get key/value states key_states = project( From 8bbda0644a66ddbc82b99d28bf2291edb899d4dd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Aug 2024 23:33:36 +0200 Subject: [PATCH 085/242] add TimesFMPositionalEmbedding --- .../models/timesfm/modeling_timesfm.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ee4701cdaded..4f45d4f4e8c8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -271,6 +271,59 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) +class TimesFMPositionalEmbedding(nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__(self, min_timescale=1, max_timescale=10000, embedding_dims=0): + super().__init__() + + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + if seq_length is None: + raise ValueError("If position is None, seq_length should be specified.") + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + if position.ndim != 2: + raise ValueError(f"position should have 2 dimensions, got {position.ndim}") + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max( + torch.tensor(num_timescales, dtype=torch.float32) - 1, 1 + ) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2).type(torch.float32) + + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + class TimesFMDenseActDense(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() From 5178c111ae0c8ccd5fa5c3e0d7d1887c7dbcef44 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 31 Aug 2024 11:31:20 +0200 Subject: [PATCH 086/242] calcuate scale_factor once --- .../models/timesfm/modeling_timesfm.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 4f45d4f4e8c8..8b338d7ae153 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -364,14 +364,13 @@ def forward(self, hidden_states): class TimesFMPerHeadDimScale(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - - self.dim = config.d_model // config.num_heads - self.scale = nn.Parameter(torch.zeros(self.dim)) + dim = config.d_model // config.num_heads + r_softplus_0 = 1.442695041 + self.scale_factor = r_softplus_0 / math.sqrt(dim) + self.scale = nn.Parameter(torch.empty(self.dim)) def forward(self, hidden_states): - r_softplus_0 = 1.442695041 - scale = r_softplus_0 / math.sqrt(self.dim) - scale *= F.softplus(self.scale) + scale = self.scale_factor * F.softplus(self.scale) return hidden_states * scale @@ -890,16 +889,8 @@ def _init_weights(self, module): module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() - elif isinstance(module, TimesFMDenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) - if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + elif isinstance(module, TimesFMPerHeadDimScale): + module.scale.data.zero_() elif isinstance(module, TimesFMAttention): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 From 95a06a9a6ad8c35aa6ab83b7432c8da3e8f0a8b2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 31 Aug 2024 12:08:30 +0200 Subject: [PATCH 087/242] add more configs and TimesFMResidualBlock --- .../models/timesfm/configuration_timesfm.py | 28 +++++++++++++- .../models/timesfm/modeling_timesfm.py | 37 ++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index ad66b752b3b3..933842f977ad 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -14,7 +14,7 @@ # limitations under the License. """TimesFM model configuration""" -from typing import Mapping +from typing import List, Mapping from ...configuration_utils import PretrainedConfig from ...onnx import OnnxSeq2SeqConfigWithPast @@ -35,12 +35,24 @@ class TimesFMConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Arguments: + patch_len (`int`, *optional*, defaults to 32): + The length of each patch in the sequence. + horizon_len (`int`, *optional*, defaults to 128): + The length of the prediction horizon. + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): + The quantiles to predict. + pad_val (`float`, *optional*, defaults to 1123581321.0): + The value used to pad the predictions. + tolerance (`float`, *optional*, defaults to 1e-6): + The tolerance for the quantile loss. + freq_size (`int`, *optional*, defaults to 3): + The number of frequency embeddings. d_model (`int`, *optional*, defaults to 512): Size of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * d_kv`. - d_ff (`int`, *optional*, defaults to 2048): + d_ff (`int`, *optional*, defaults to 1280): Size of the intermediate feed forward layer in each `TimesFMBlock`. num_layers (`int`, *optional*, defaults to 6): Number of hidden layers in the Transformer encoder. @@ -74,6 +86,12 @@ class TimesFMConfig(PretrainedConfig): def __init__( self, + patch_len: int = 32, + horizon_len: int = 128, + quantiles: List[float] = [0.1, 0.25, 0.5, 0.75, 0.9], + pad_val: float = 1123581321.0, + tolerance: float = 1e-6, + freq_size=3, d_model=1280, d_kv=80, d_ff=1280, @@ -93,6 +111,12 @@ def __init__( classifier_dropout=0.0, **kwargs, ): + self.patch_len = patch_len + self.horizon_len = horizon_len + self.quantiles = quantiles + self.pad_val = pad_val + self.tolerance = tolerance + self.freq_size = freq_size self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8b338d7ae153..61010b17929d 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -271,6 +271,24 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) +class TimesFMResidualBlock(nn.Module): + def __init__(self, input_dims, hidden_dims, output_dims, dropout=0.1): + super().__init__() + + self.hidden_layer = nn.Sequential(nn.Linear(input_dims, hidden_dims), nn.SiLU()) + self.output_layer = nn.Linear(hidden_dims, output_dims) + self.residual_layer = nn.Linear(input_dims, output_dims) + self.dropout = nn.Dropout(dropout) + + def forward(self, inputs): + hidden = self.hidden_layer(inputs) + output = self.output_layer(hidden) + output = self.dropout(output) + residual = self.residual_layer(inputs) + + return output + residual + + class TimesFMPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. @@ -932,7 +950,6 @@ def _shift_right(self, input_ids): return shifted_input_ids -# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->TimesFM class TimesFMStack(TimesFMPreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) @@ -943,7 +960,6 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList( [TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ) - self.final_layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) # Initialize weights and apply final processing @@ -1182,7 +1198,6 @@ def forward( if i == v[-1] and "cuda:" + str(k) != self.last_device: hidden_states = hidden_states.to("cuda:" + str(k + 1)) - hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) # Add last layer @@ -1382,7 +1397,13 @@ class TimesFMModel(TimesFMPreTrainedModel): def __init__(self, config: TimesFMConfig): super().__init__(config) - self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.freq_emb = nn.Embedding(config.freq_size, config.d_model) + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.hidden_size, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False @@ -1909,7 +1930,13 @@ class TimesFMEncoderModel(TimesFMPreTrainedModel): def __init__(self, config: TimesFMConfig): super().__init__(config) - self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.freq_emb = nn.Embedding(config.freq_size, config.d_model) + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.d_model, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) encoder_config = copy.deepcopy(config) encoder_config.use_cache = False From 3be589342478a79db0aa36ed0dfaecc2222ae0c3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 31 Aug 2024 22:02:11 +0200 Subject: [PATCH 088/242] fix input_dims --- .../models/timesfm/modeling_timesfm.py | 105 +----------------- 1 file changed, 2 insertions(+), 103 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 61010b17929d..1a2d93b67694 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -951,10 +951,9 @@ def _shift_right(self, input_ids): class TimesFMStack(TimesFMPreTrainedModel): - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) - self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder self.block = nn.ModuleList( @@ -965,60 +964,9 @@ def __init__(self, config, embed_tokens=None): # Initialize weights and apply final processing self.post_init() # Model parallel - self.model_parallel = False self.device_map = None self.gradient_checkpointing = False - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMStack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" - " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," - " 'block.1': 1, ...}", - FutureWarning, - ) - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - # Load onto devices - for k, v in self.device_map.items(): - for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - # Set embed_tokens to first layer - self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device - self.final_layer_norm = self.final_layer_norm.to(self.last_device) - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, new_embeddings): - self.embed_tokens = new_embeddings - def forward( self, input_ids=None, @@ -1399,7 +1347,7 @@ def __init__(self, config: TimesFMConfig): super().__init__(config) self.freq_emb = nn.Embedding(config.freq_size, config.d_model) self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.hidden_size, + input_dims=config.d_model, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.d_ff, dropout=config.dropout_rate, @@ -1950,58 +1898,9 @@ def __init__(self, config: TimesFMConfig): self.model_parallel = False self.device_map = None - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMEncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" - " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," - " 'block.1': 1, ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.encoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) - @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) def forward( From 9fb8bf84086ea66dcfd94085f0f1862bfdfb69b4 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 5 Sep 2024 14:41:53 -0700 Subject: [PATCH 089/242] standardize code format with black --- src/transformers/models/timesfm/__init__.py | 4 +- .../models/timesfm/configuration_timesfm.py | 16 +- ...mesfm_original_tf_checkpoint_to_pytorch.py | 26 +- .../convert_timesfmx_checkpoint_to_flax.py | 320 +++++++------ .../convert_timesfmx_checkpoint_to_pytorch.py | 79 +++- .../models/timesfm/modeling_timesfm.py | 434 +++++++++++++----- 6 files changed, 619 insertions(+), 260 deletions(-) diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 7398dccbda88..1abef3d3e175 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -63,4 +63,6 @@ else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__ + ) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 933842f977ad..cdee64d7f377 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -82,7 +82,11 @@ class TimesFMConfig(PretrainedConfig): model_type = "timesfm" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } def __init__( self, @@ -167,10 +171,16 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: if self.use_past: common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" common_inputs["decoder_input_ids"] = {0: "batch"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + common_inputs["decoder_attention_mask"] = { + 0: "batch", + 1: "past_decoder_sequence + sequence", + } else: common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = { + 0: "batch", + 1: "decoder_sequence", + } if self.use_past: self.fill_with_past_key_values_(common_inputs, direction="inputs") diff --git a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py index aa66a8392d4f..b1ce727cac0c 100644 --- a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py @@ -16,14 +16,20 @@ import argparse -from transformers import TimesFMConfig, TimesFMForConditionalGeneration, load_tf_weights_in_timesfm +from transformers import ( + TimesFMConfig, + TimesFMForConditionalGeneration, + load_tf_weights_in_timesfm, +) from transformers.utils import logging logging.set_verbosity_info() -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): +def convert_tf_checkpoint_to_pytorch( + tf_checkpoint_path, config_file, pytorch_dump_path +): # Initialise PyTorch model config = TimesFMConfig.from_json_file(config_file) print(f"Building PyTorch model from configuration: {config}") @@ -41,7 +47,11 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + "--tf_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint path.", ) parser.add_argument( "--config_file", @@ -53,7 +63,13 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du ), ) parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + "--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", ) args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path + ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py index 98570e22876e..f9468ffb84c6 100644 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py @@ -22,7 +22,9 @@ from transformers import FlaxTimesFMForConditionalGeneration, TimesFMConfig -def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, flax_dump_folder_path): +def convert_timesfmx_checkpoint_to_flax( + timesfmx_checkpoint_path, config_name, flax_dump_folder_path +): config = TimesFMConfig.from_pretrained(config_name) flax_model = FlaxTimesFMForConditionalGeneration(config=config) timesfmx_model = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) @@ -34,67 +36,89 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f layer_name = f"layers_{str(layer_index)}" # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name][ + "attention" + ]["value"]["kernel"] # Layer Normalization - timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ + "pre_attention_layer_norm" + ]["scale"] if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ + "wi_0" + ]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ + "wi_1" + ]["kernel"] else: - timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ + "wi" + ]["kernel"] - timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"][ + "kernel" + ] # Layer Normalization - timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ + "pre_mlp_layer_norm" + ]["scale"] # Assigning - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( - timesfmx_attention_key - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( - timesfmx_attention_out - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( - timesfmx_attention_query - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( - timesfmx_attention_value - ) - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( - timesfmx_attention_layer_norm - ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["k"]["kernel"] = timesfmx_attention_key + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["o"]["kernel"] = timesfmx_attention_out + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["q"]["kernel"] = timesfmx_attention_query + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["v"]["kernel"] = timesfmx_attention_value + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ + "layer_norm" + ]["weight"] = timesfmx_attention_layer_norm if split_mlp_wi: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = timesfmx_mlp_wi_0 - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = timesfmx_mlp_wi_1 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 else: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = ( - timesfmx_mlp_wi - ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wi"]["kernel"] = timesfmx_mlp_wi - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = ( - timesfmx_mlp_wo - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( - timesfmx_mlp_layer_norm - ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "DenseReluDense" + ]["wo"]["kernel"] = timesfmx_mlp_wo + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ + "layer_norm" + ]["weight"] = timesfmx_mlp_layer_norm # Only for layer 0: - timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = timesfmx_encoder_rel_embedding + timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"][ + "rel_embedding" + ].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ + "relative_attention_bias" + ]["embedding"] = timesfmx_encoder_rel_embedding # Assigning timesfmx_encoder_norm = timesfmx_model["target"]["encoder"]["encoder_norm"]["scale"] @@ -105,109 +129,131 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f layer_name = f"layers_{str(layer_index)}" # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["key"]["kernel"] + timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["out"]["kernel"] + timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["query"]["kernel"] + timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["value"]["kernel"] # Layer Normalization - timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ - "scale" - ] + timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][ + layer_name + ]["pre_self_attention_layer_norm"]["scale"] # Encoder-Decoder-Attention - timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ - "kernel" - ] - timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ - "kernel" - ] - timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ - "kernel" - ] - timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ - "kernel" - ] + timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["key"]["kernel"] + timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["out"]["kernel"] + timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["query"]["kernel"] + timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][ + layer_name + ]["encoder_decoder_attention"]["value"]["kernel"] # Layer Normalization - timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ + "pre_cross_attention_layer_norm" + ]["scale"] # MLP if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ + "wi_0" + ]["kernel"] + timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ + "wi_1" + ]["kernel"] else: - timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ + "wi" + ]["kernel"] - timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"][ + "kernel" + ] # Layer Normalization - tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ + "pre_mlp_layer_norm" + ]["scale"] # Assigning - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( - timesfmx_attention_key - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( - timesfmx_attention_out - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( - timesfmx_attention_query - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( - timesfmx_attention_value - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( - timesfmx_pre_attention_layer_norm - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = ( - timesfmx_enc_dec_attention_key - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = ( - timesfmx_enc_dec_attention_out - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = ( - timesfmx_enc_dec_attention_query - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = ( - timesfmx_enc_dec_attention_value - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( - timesfmx_cross_layer_norm - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["k"]["kernel"] = timesfmx_attention_key + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["o"]["kernel"] = timesfmx_attention_out + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["q"]["kernel"] = timesfmx_attention_query + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "SelfAttention" + ]["v"]["kernel"] = timesfmx_attention_value + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ + "layer_norm" + ]["weight"] = timesfmx_pre_attention_layer_norm + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["k"]["kernel"] = timesfmx_enc_dec_attention_key + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["o"]["kernel"] = timesfmx_enc_dec_attention_out + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["q"]["kernel"] = timesfmx_enc_dec_attention_query + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "EncDecAttention" + ]["v"]["kernel"] = timesfmx_enc_dec_attention_value + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ + "layer_norm" + ]["weight"] = timesfmx_cross_layer_norm if split_mlp_wi: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = timesfmx_mlp_wi_0 - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = timesfmx_mlp_wi_1 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 else: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = ( - timesfmx_mlp_wi - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wi"]["kernel"] = timesfmx_mlp_wi - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = ( - timesfmx_mlp_wo - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "DenseReluDense" + ]["wo"]["kernel"] = timesfmx_mlp_wo - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = ( - tx5_mlp_layer_norm - ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ + "layer_norm" + ]["weight"] = tx5_mlp_layer_norm # Decoder Normalization tx5_decoder_norm = timesfmx_model["target"]["decoder"]["decoder_norm"]["scale"] flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm # Only for layer 0: - timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = timesfmx_decoder_rel_embedding + timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"][ + "rel_embedding" + ].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ + "relative_attention_bias" + ]["embedding"] = timesfmx_decoder_rel_embedding # Token Embeddings tx5_token_embeddings = timesfmx_model["target"]["token_embedder"]["embedding"] @@ -215,7 +261,9 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f # LM Head (only in v1.1 checkpoints) if "logits_dense" in timesfmx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"]["logits_dense"]["kernel"] + flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"][ + "logits_dense" + ]["kernel"] flax_model.save_pretrained(flax_dump_folder_path) print("TimesFMX Model was sucessfully converted!") @@ -225,11 +273,27 @@ def convert_timesfmx_checkpoint_to_flax(timesfmx_checkpoint_path, config_name, f parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + "--timesfmx_checkpoint_path", + default=None, + type=str, + required=True, + help="Path the TX5 checkpoint.", + ) + parser.add_argument( + "--config_name", + default=None, + type=str, + required=True, + help="Config name of TimesFM model.", ) - parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of TimesFM model.") parser.add_argument( - "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + "--flax_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output FLAX model.", ) args = parser.parse_args() - convert_timesfmx_checkpoint_to_flax(args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path) + convert_timesfmx_checkpoint_to_flax( + args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path + ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py index b761d76bbdcd..8d5f13535e8d 100644 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py @@ -35,7 +35,11 @@ from flax import traverse_util from timesfmx import checkpoints -from transformers import TimesFMConfig, TimesFMEncoderModel, TimesFMForConditionalGeneration +from transformers import ( + TimesFMConfig, + TimesFMEncoderModel, + TimesFMForConditionalGeneration, +) from transformers.utils import logging @@ -69,7 +73,9 @@ def timesfmx_layer_norm_lookup(params, i, prefix, layer_name): return params[f"{prefix}/layers_{i}/{layer_name}/scale"] -def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool): +def convert_timesfmx_to_pytorch( + variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool +): """Converts the parameters from TimesFMX-Flax to Transformers-PyTorch.""" old = traverse_util.flatten_dict(variables["target"]) old = {"/".join(k): v for k, v in old.items()} @@ -86,7 +92,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder # Encoder. for i in range(num_layers): # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "encoder", "pre_attention_layer_norm" + ) k, o, q, v = timesfmx_attention_lookup(old, i, "encoder", "attention") new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T @@ -114,7 +122,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder # Decoder. for i in range(num_decoder_layers): # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "decoder", "pre_self_attention_layer_norm" + ) k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "self_attention") new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T @@ -123,8 +133,12 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T # Block i, layer 1 (Cross Attention). - layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") - k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "encoder_decoder_attention") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "decoder", "pre_cross_attention_layer_norm" + ) + k, o, q, v = timesfmx_attention_lookup( + old, i, "decoder", "encoder_decoder_attention" + ) new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T @@ -132,7 +146,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T # Block i, layer 2 (MLP). - layer_norm = timesfmx_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + layer_norm = timesfmx_layer_norm_lookup( + old, i, "decoder", "pre_mlp_layer_norm" + ) wi, wo = timesfmx_mlp_lookup(old, i, "decoder", split_mlp_wi) new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm if split_mlp_wi: @@ -143,9 +159,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] - new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ - "decoder/relpos_bias/rel_embedding" - ].T + new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = ( + old["decoder/relpos_bias/rel_embedding"].T + ) # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) if "decoder/logits_dense/kernel" in old: @@ -157,7 +173,9 @@ def convert_timesfmx_to_pytorch(variables: dict, *, num_layers: int, num_decoder def make_state_dict(converted_params, is_encoder_only: bool): """Prepares a state dict for the PyTorch model.""" # Make a state dict with torch tensors. - state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + state_dict = collections.OrderedDict( + [(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()] + ) # Add what is missing. if "encoder.embed_tokens.weight" not in state_dict: @@ -174,7 +192,9 @@ def make_state_dict(converted_params, is_encoder_only: bool): return state_dict -def load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only): +def load_timesfmx_weights_in_timesfm( + model, config, timesfmx_checkpoint_path, is_encoder_only +): """Replaces the params in model witht the TimesFMX converted params.""" variables = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) converted = convert_timesfmx_to_pytorch( @@ -188,7 +208,10 @@ def load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is def convert_timesfmx_checkpoint_to_pytorch( - timesfmx_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False + timesfmx_checkpoint_path, + config_file, + pytorch_dump_path, + is_encoder_only: bool = False, ): """Loads the config and model, converts the TimesFMX checkpoint, and saves a PyTorch checkpoint.""" # Initialise PyTorch model @@ -202,7 +225,9 @@ def convert_timesfmx_checkpoint_to_pytorch( model = TimesFMForConditionalGeneration(config) # Load weights from tf checkpoint - load_timesfmx_weights_in_timesfm(model, config, timesfmx_checkpoint_path, is_encoder_only) + load_timesfmx_weights_in_timesfm( + model, config, timesfmx_checkpoint_path, is_encoder_only + ) # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") @@ -214,10 +239,16 @@ def convert_timesfmx_checkpoint_to_pytorch( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint.") + parser = argparse.ArgumentParser( + description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint." + ) # Required parameters parser.add_argument( - "--timesfmx_checkpoint_path", default=None, type=str, required=True, help="Path to the TimesFMX checkpoint." + "--timesfmx_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the TimesFMX checkpoint.", ) parser.add_argument( "--config_file", @@ -227,12 +258,22 @@ def convert_timesfmx_checkpoint_to_pytorch( help="The config json file corresponding to the pre-trained TimesFM model.\nThis specifies the model architecture.", ) parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + "--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", ) parser.add_argument( - "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False + "--is_encoder_only", + action="store_true", + help="Check if the model is encoder-decoder model", + default=False, ) args = parser.parse_args() convert_timesfmx_checkpoint_to_pytorch( - args.timesfmx_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only + args.timesfmx_checkpoint_path, + args.config_file, + args.pytorch_dump_path, + args.is_encoder_only, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 1a2d93b67694..8542c54277bc 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -36,7 +36,11 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -96,7 +100,14 @@ def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] for n in name ): logger.info(f"Skipping {'/'.join(name)}") @@ -140,7 +151,11 @@ def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): continue elif scope_names[0] == "logits": pointer = getattr(pointer, "lm_head") - elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + elif ( + scope_names[0] == "wi" + and len(scope_names) > 1 + and scope_names[1].isdigit() + ): pointer = getattr(pointer, f"wi_{scope_names[1]}") continue else: @@ -159,7 +174,9 @@ def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): array = np.transpose(array) try: if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + raise ValueError( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) except AssertionError as e: e.args += (pointer.shape, array.shape) raise @@ -260,12 +277,16 @@ def forward(self, hidden_states): TimesFMLayerNorm = FusedRMSNorm # noqa - logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm") + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm" + ) except ImportError: # using the normal TimesFMLayerNorm pass except Exception: - logger.warning("discovered apex but it failed to load, falling back to TimesFMLayerNorm") + logger.warning( + "discovered apex but it failed to load, falling back to TimesFMLayerNorm" + ) pass ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) @@ -326,17 +347,21 @@ def forward(self, seq_length=None, position=None): position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) else: if position.ndim != 2: - raise ValueError(f"position should have 2 dimensions, got {position.ndim}") + raise ValueError( + f"position should have 2 dimensions, got {position.ndim}" + ) num_timescales = self.embedding_dims // 2 - log_timescale_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max( - torch.tensor(num_timescales, dtype=torch.float32) - 1, 1 - ) + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / max(torch.tensor(num_timescales, dtype=torch.float32) - 1, 1) inv_timescales = self.min_timescale * torch.exp( torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment ) scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2).type(torch.float32) + signal = torch.cat( + [torch.sin(scaled_time), torch.cos(scaled_time)], dim=2 + ).type(torch.float32) signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) return signal @@ -413,7 +438,9 @@ def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): self.per_head_dim_scale = TimesFMPerHeadDimScale(config) if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) self.pruned_heads = set() self.gradient_checkpointing = False @@ -434,7 +461,9 @@ def prune_heads(self, heads): self.pruned_heads = self.pruned_heads.union(heads) @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 @@ -461,7 +490,9 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -475,27 +506,40 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets * (num_buckets - max_exact) ).to(torch.long) relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), ) - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) return relative_buckets def compute_bias(self, query_length, key_length, device=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), num_buckets=self.relative_attention_num_buckets, max_distance=self.relative_attention_max_distance, ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) return values def forward( @@ -525,17 +569,25 @@ def forward( raise ValueError( f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) def shape(states): """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) def unshape(states): """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -565,15 +617,23 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return hidden_states # get query states - unscaled_query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + unscaled_query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) query_states = self.per_head_dim_scale(unscaled_query_states) # get key/value states key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, ) value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, ) # compute scores @@ -584,12 +644,16 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) # if key and values are already calculated # we want only the last query position bias @@ -597,7 +661,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -618,10 +684,14 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) if output_attentions: @@ -633,8 +703,12 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class TimesFMLayerSelfAttention(nn.Module): def __init__(self, config, has_relative_attention_bias=False): super().__init__() - self.SelfAttention = TimesFMAttention(config, has_relative_attention_bias=has_relative_attention_bias) - self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.SelfAttention = TimesFMAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = TimesFMLayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -658,7 +732,9 @@ def forward( output_attentions=output_attentions, ) hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them return outputs @@ -666,8 +742,12 @@ def forward( class TimesFMLayerCrossAttention(nn.Module): def __init__(self, config): super().__init__() - self.EncDecAttention = TimesFMAttention(config, has_relative_attention_bias=False) - self.layer_norm = TimesFMLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.EncDecAttention = TimesFMAttention( + config, has_relative_attention_bias=False + ) + self.layer_norm = TimesFMLayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -695,7 +775,9 @@ def forward( output_attentions=output_attentions, ) layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them return outputs @@ -705,7 +787,11 @@ def __init__(self, config, has_relative_attention_bias=False): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(TimesFMLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + TimesFMLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) if self.is_decoder: self.layer.append(TimesFMLayerCrossAttention(config)) @@ -728,7 +814,9 @@ def forward( ): if past_key_value is not None: if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 if len(past_key_value) != expected_num_past_key_values: @@ -753,7 +841,9 @@ def forward( output_attentions=output_attentions, ) hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -762,7 +852,9 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: @@ -793,11 +885,15 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) # Combine self attn and cross attn key value states if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -812,7 +908,9 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) outputs = (hidden_states,) @@ -871,12 +969,19 @@ def dummy_inputs(self): def _init_weights(self, module): """Initialize the weights""" - factor = self.config.initializer_factor # Used for testing weights initialization + factor = ( + self.config.initializer_factor + ) # Used for testing weights initialization if isinstance(module, TimesFMLayerNorm): module.weight.data.fill_(factor * 1.0) elif isinstance( module, - (TimesFMModel, TimesFMForConditionalGeneration, TimesFMEncoderModel, TimesFMForQuestionAnswering), + ( + TimesFMModel, + TimesFMForConditionalGeneration, + TimesFMEncoderModel, + TimesFMForQuestionAnswering, + ), ): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 @@ -884,27 +989,37 @@ def _init_weights(self, module): if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) module.qa_outputs.bias.data.zero_() elif isinstance(module, TimesFMForTokenClassification): if hasattr(module, "classifier"): module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) module.classifier.bias.data.zero_() elif isinstance(module, TimesFMClassificationHead): - module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.dense, "bias") and module.dense.bias is not None: module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.out_proj.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: module.out_proj.bias.data.zero_() elif isinstance(module, TimesFMDenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, TimesFMPerHeadDimScale): @@ -915,12 +1030,18 @@ def _init_weights(self, module): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -935,8 +1056,12 @@ def _shift_right(self, input_ids): # shift inputs to the right if is_torch_fx_proxy(input_ids): # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) - shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) else: shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() @@ -957,7 +1082,10 @@ def __init__(self, config): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] ) self.dropout = nn.Dropout(config.dropout_rate) @@ -987,11 +1115,19 @@ def forward( torch.cuda.set_device(self.first_device) self.embed_tokens = self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: err_msg_prefix = "decoder_" if self.is_decoder else "" @@ -1005,43 +1141,61 @@ def forward( input_shape = inputs_embeds.size()[:-1] else: err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) if inputs_embeds is None: if self.embed_tokens is None: - raise ValueError("You have to initialize the model with valid token embeddings") + raise ValueError( + "You have to initialize the model with valid token embeddings" + ) inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) if use_cache is True: if not self.is_decoder: - raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + raise ValueError( + f"`use_cache` can only be set to `True` if {self} is used as a decoder" + ) # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_batch_size, encoder_sequence_length, _ = ( + encoder_hidden_states.size() + ) encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones( encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long ) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) else: encoder_extended_attention_mask = None @@ -1054,7 +1208,9 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1064,7 +1220,9 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel @@ -1076,15 +1234,23 @@ def forward( if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device + ) if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + encoder_extended_attention_mask = ( + encoder_extended_attention_mask.to(hidden_states.device) + ) if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) if layer_head_mask is not None: layer_head_mask = layer_head_mask.to(hidden_states.device) if cross_attn_layer_head_mask is not None: - cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1130,10 +1296,14 @@ def forward( # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] # append next layer key value states if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1433,7 +1603,9 @@ class PreTrainedModel self.encoder.layer[layer].attention.prune_heads(heads) @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1477,7 +1649,9 @@ def forward( >>> last_hidden_states = outputs.last_hidden_state ```""" use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: @@ -1514,7 +1688,9 @@ def forward( if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) # Decode decoder_outputs = self.decoder( @@ -1547,12 +1723,18 @@ def forward( ) -@add_start_docstrings("""TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING) +@add_start_docstrings( + """TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING +) class TimesFMForConditionalGeneration(TimesFMPreTrainedModel): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "lm_head.weight", + ] def __init__(self, config: TimesFMConfig): super().__init__(config) @@ -1642,7 +1824,9 @@ def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1694,7 +1878,9 @@ def forward( >>> # studies have shown that owning a dog is good for you. ```""" use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: @@ -1726,7 +1912,11 @@ def forward( if self.model_parallel: torch.cuda.set_device(self.decoder.first_device) - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) @@ -1739,7 +1929,9 @@ def forward( if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) # Decode decoder_outputs = self.decoder( @@ -1841,7 +2033,9 @@ def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) return past_key_values reordered_decoder_past = () @@ -1852,7 +2046,9 @@ def _reorder_cache(self, past_key_values, beam_idx): for layer_past_state in layer_past_states: # need to set correct `past` for each of the four key / value states reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), ) if reordered_layer_past_states[0].shape != layer_past_states[0].shape: @@ -1864,7 +2060,9 @@ def _reorder_cache(self, past_key_values, beam_idx): f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" ) - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) return reordered_decoder_past @@ -1902,7 +2100,9 @@ def get_encoder(self): return self.encoder @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1929,7 +2129,9 @@ def forward( >>> outputs = model(input_ids=input_ids) >>> last_hidden_states = outputs.last_hidden_state ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) encoder_outputs = self.encoder( input_ids=input_ids, @@ -1952,7 +2154,9 @@ def forward( TIMESFM_START_DOCSTRING, ) class TimesFMForSequenceClassification(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" + ] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] def __init__(self, config: TimesFMConfig): @@ -1966,7 +2170,9 @@ def __init__(self, config: TimesFMConfig): self.model_parallel = False @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: torch.LongTensor = None, @@ -1991,7 +2197,9 @@ def forward( config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) if labels is not None: use_cache = False @@ -2033,7 +2241,9 @@ def forward( if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") batch_size, _, hidden_size = sequence_output.shape - sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + sentence_representation = sequence_output[eos_mask, :].view( + batch_size, -1, hidden_size + )[:, -1, :] logits = self.classification_head(sentence_representation) loss = None @@ -2042,7 +2252,9 @@ def forward( if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" - elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + elif self.config.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -2055,7 +2267,9 @@ def forward( loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + loss = loss_fct( + logits.view(-1, self.config.num_labels), labels.view(-1) + ) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) @@ -2098,7 +2312,9 @@ def __init__(self, config: TimesFMConfig): self.post_init() @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -2115,7 +2331,9 @@ def forward( Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) outputs = self.transformer( input_ids, @@ -2156,7 +2374,9 @@ def forward( TIMESFM_START_DOCSTRING, ) class TimesFMForQuestionAnswering(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" + ] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] def __init__(self, config: TimesFMConfig): @@ -2205,7 +2425,9 @@ def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2236,7 +2458,9 @@ def forward( are not taken into account for computing the loss. Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) use_cache = use_cache if use_cache is not None else self.config.use_cache if start_positions is not None and end_positions is not None: use_cache = False @@ -2253,7 +2477,9 @@ def forward( decoder_input_ids = self._shift_right(input_ids) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: From f79803cbffd4b7fc8fd9c5e9fb140d2909a435eb Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 5 Sep 2024 15:49:34 -0700 Subject: [PATCH 090/242] remove unneeded modules --- src/transformers/__init__.py | 14 +- src/transformers/models/auto/modeling_auto.py | 10 +- .../models/auto/tokenization_auto.py | 6 - src/transformers/models/timesfm/__init__.py | 14 +- .../models/timesfm/modeling_timesfm.py | 512 +----------------- tests/models/timesfm/test_modeling_timesfm.py | 156 +----- 6 files changed, 33 insertions(+), 679 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6a09848414a8..1e36fe03fbc1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3710,14 +3710,9 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMEncoderModel", - "TimesFMForConditionalGeneration", - "TimesFMForQuestionAnswering", - "TimesFMForSequenceClassification", - "TimesFMForTokenClassification", + "TimesFMForPrediction", "TimesFMModel", "TimesFMPreTrainedModel", - "load_tf_weights_in_timesfm", ] ) _import_structure["models.table_transformer"].extend( @@ -8436,14 +8431,9 @@ load_tf_weights_in_t5, ) from .models.timesfm import ( - TimesFMEncoderModel, - TimesFMForConditionalGeneration, - TimesFMForQuestionAnswering, - TimesFMForSequenceClassification, - TimesFMForTokenClassification, + TimesFMForPrediction, TimesFMModel, TimesFMPreTrainedModel, - load_tf_weights_in_timesfm, ) from .models.table_transformer import ( TableTransformerForObjectDetection, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0fe7488a4605..b5371967eb02 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -380,7 +380,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForConditionalGeneration"), + ("timesfm", "TimesFMForPrediction"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -475,7 +475,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForConditionalGeneration"), + ("timesfm", "TimesFMForPrediction"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -963,7 +963,7 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForConditionalGeneration"), + ("timesfm", "TimesFMForPrediction"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] @@ -1069,7 +1069,6 @@ ("stablelm", "StableLmForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), - ("timesfm", "TimesFMForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), @@ -1149,7 +1148,6 @@ ("splinter", "SplinterForQuestionAnswering"), ("squeezebert", "SqueezeBertForQuestionAnswering"), ("t5", "T5ForQuestionAnswering"), - ("timesfm", "TimesFMForQuestionAnswering"), ("umt5", "UMT5ForQuestionAnswering"), ("xlm", "XLMForQuestionAnsweringSimple"), ("xlm-roberta", "XLMRobertaForQuestionAnswering"), @@ -1254,7 +1252,6 @@ ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), - ("timesfm", "TimesFMForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -1483,7 +1480,6 @@ ("roformer", "RoFormerModel"), ("squeezebert", "SqueezeBertModel"), ("t5", "T5EncoderModel"), - ("timesfm", "TimesFMEncoderModel"), ("umt5", "UMT5EncoderModel"), ("xlm", "XLMModel"), ("xlm-roberta", "XLMRobertaModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index d77d40720795..61c2c2e23d2f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -510,12 +510,6 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "timesfm", - ( - "T5Tokenizer" if is_sentencepiece_available() else None, - "T5TokenizerFast" if is_tokenizers_available() else None, - ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 1abef3d3e175..baa30b11af21 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -30,14 +30,9 @@ pass else: _import_structure["modeling_timesfm"] = [ - "TimesFMEncoderModel", - "TimesFMForConditionalGeneration", + "TimesFMForPrediction", "TimesFMModel", "TimesFMPreTrainedModel", - "load_tf_weights_in_timesfm", - "TimesFMForQuestionAnswering", - "TimesFMForSequenceClassification", - "TimesFMForTokenClassification", ] if TYPE_CHECKING: @@ -50,14 +45,9 @@ pass else: from .modeling_timesfm import ( - TimesFMEncoderModel, - TimesFMForConditionalGeneration, - TimesFMForQuestionAnswering, - TimesFMForSequenceClassification, - TimesFMForTokenClassification, + TimesFMForPrediction, TimesFMModel, TimesFMPreTrainedModel, - load_tf_weights_in_timesfm, ) else: diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8542c54277bc..2d35ecdeeca3 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -31,9 +31,6 @@ BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, - Seq2SeqQuestionAnsweringModelOutput, - Seq2SeqSequenceClassifierOutput, - TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ( @@ -978,9 +975,7 @@ def _init_weights(self, module): module, ( TimesFMModel, - TimesFMForConditionalGeneration, - TimesFMEncoderModel, - TimesFMForQuestionAnswering, + TimesFMForPrediction, ), ): # Mesh TensorFlow embeddings initialization @@ -993,10 +988,6 @@ def _init_weights(self, module): mean=0.0, std=factor * ((self.config.d_model) ** -0.5) ) module.qa_outputs.bias.data.zero_() - elif isinstance(module, TimesFMForTokenClassification): - if hasattr(module, "classifier"): - module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.data.zero_() elif isinstance(module, TimesFMClassificationHead): module.dense.weight.data.normal_( mean=0.0, std=factor * ((self.config.d_model) ** -0.5) @@ -1726,7 +1717,7 @@ def forward( @add_start_docstrings( """TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING ) -class TimesFMForConditionalGeneration(TimesFMPreTrainedModel): +class TimesFMForPrediction(TimesFMPreTrainedModel): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] @@ -2064,502 +2055,3 @@ def _reorder_cache(self, past_key_values, beam_idx): reordered_layer_past_states, ) return reordered_decoder_past - - -@add_start_docstrings( - "The bare TIMESFM Model transformer outputting encoder's raw hidden-states without any specific head on top.", - TIMESFM_START_DOCSTRING, -) -class TimesFMEncoderModel(TimesFMPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight"] - _keys_to_ignore_on_load_unexpected = [r"decoder"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.freq_emb = nn.Embedding(config.freq_size, config.d_model) - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.d_model, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - def get_encoder(self): - return self.encoder - - @add_start_docstrings_to_model_forward(TIMESFM_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, TimesFMEncoderModel - - >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") - >>> model = TimesFMEncoderModel.from_pretrained("google/timesfm-1.0-200m") - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> outputs = model(input_ids=input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - return encoder_outputs - - -@add_start_docstrings( - """ - TIMESFM model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. - """, - TIMESFM_START_DOCSTRING, -) -class TimesFMForSequenceClassification(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" - ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.transformer = TimesFMModel(config) - self.classification_head = TimesFMClassificationHead(config) - - # Initialize weights and apply final processing - self.post_init() - - self.model_parallel = False - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - Returns: - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - if labels is not None: - use_cache = False - - if input_ids is None and inputs_embeds is not None: - raise NotImplementedError( - f"Passing input embeddings is currently not supported for {self.__class__.__name__}" - ) - - # decoder_input_ids from input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - decoder_input_ids = self._shift_right(input_ids) - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - - eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) - - if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - batch_size, _, hidden_size = sequence_output.shape - sentence_representation = sequence_output[eos_mask, :].view( - batch_size, -1, hidden_size - )[:, -1, :] - logits = self.classification_head(sentence_representation) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.config.num_labels == 1: - self.config.problem_type = "regression" - elif self.config.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.config.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.config.num_labels), labels.view(-1) - ) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Seq2SeqSequenceClassifierOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - TIMESFM Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output) - e.g. for Named-Entity-Recognition (NER) tasks. - """, - TIMESFM_START_DOCSTRING, -) -class TimesFMForTokenClassification(TimesFMPreTrainedModel): - _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.num_labels = config.num_labels - - self.transformer = TimesFMEncoderModel(config) - self.dropout = nn.Dropout(config.classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - Returns: - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits, outputs[2:-1]) - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - TIMESFM Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers - on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - TIMESFM_START_DOCSTRING, -) -class TimesFMForQuestionAnswering(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" - ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TimesFMStack(decoder_config, self.shared) - - self.num_labels = config.num_labels - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - self.model_parallel = False - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - Returns: - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if start_positions is not None and end_positions is not None: - use_cache = False - - # different to other models, TIMESFM automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - decoder_input_ids = self._shift_right(input_ids) - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=None, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1).to(start_logits.device) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1).to(end_logits.device) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs - return ((total_loss,) + output) if total_loss is not None else output - - return Seq2SeqQuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index e5878f8c51c7..e08277fac50f 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -21,7 +21,6 @@ import unittest from transformers import TimesFMConfig, is_torch_available -from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES from transformers.testing_utils import ( require_accelerate, require_sentencepiece, @@ -48,11 +47,7 @@ from transformers import ( AutoTokenizer, ByT5Tokenizer, - TimesFMEncoderModel, - TimesFMForConditionalGeneration, - TimesFMForQuestionAnswering, - TimesFMForSequenceClassification, - TimesFMForTokenClassification, + TimesFMForPrediction, TimesFMModel, T5Tokenizer, ) @@ -249,7 +244,7 @@ def create_and_check_with_lm_head( decoder_attention_mask, lm_labels, ): - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() outputs = model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, @@ -260,26 +255,6 @@ def create_and_check_with_lm_head( self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) self.parent.assertEqual(outputs["loss"].size(), ()) - def create_and_check_with_sequence_classification_head( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device) - model = TimesFMForSequenceClassification(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - decoder_input_ids=input_ids, - labels=labels, - ) - # self.parent.assertEqual(len(outputs), 4) - self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels)) - self.parent.assertEqual(outputs["loss"].size(), ()) - def create_and_check_decoder_model_past( self, config, @@ -415,7 +390,7 @@ def create_and_check_generate_with_past_key_values( decoder_attention_mask, lm_labels, ): - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() torch.manual_seed(0) output_without_past_cache = model.generate( input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False @@ -446,7 +421,7 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask, lm_labels, ): - for model_class in [TimesFMModel, TimesFMForConditionalGeneration]: + for model_class in [TimesFMModel, TimesFMForPrediction]: torch.manual_seed(0) model = model_class(config=config).to(torch_device).eval() # load state dict copies weights but does not tie them @@ -520,7 +495,7 @@ def check_resize_embeddings_timesfm_v1_1( prev_vocab_size = config.vocab_size config.tie_word_embeddings = False - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() model.resize_token_embeddings(prev_vocab_size - 10) self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) @@ -551,12 +526,12 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForSequenceClassification, TimesFMForQuestionAnswering) + (TimesFMModel, TimesFMForPrediction) if is_torch_available() else () ) - all_generative_model_classes = (TimesFMForConditionalGeneration,) if is_torch_available() else () - all_parallelizable_model_classes = (TimesFMModel, TimesFMForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (TimesFMForPrediction,) if is_torch_available() else () + all_parallelizable_model_classes = (TimesFMModel, TimesFMForPrediction) if is_torch_available() else () fx_compatible = False test_pruning = False test_resize_embeddings = True @@ -573,7 +548,7 @@ def setUp(self): def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in (TimesFMModel, TimesFMForConditionalGeneration, TimesFMForQuestionAnswering): + for model_class in (TimesFMModel, TimesFMForPrediction): model = model_class(config) model.to(torch_device) model.eval() @@ -609,10 +584,6 @@ def test_with_lm_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_lm_head(*config_and_inputs) - def test_with_sequence_classification_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) - def test_decoder_model_past(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) @@ -695,7 +666,7 @@ def test_generate_with_head_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() config = config_and_inputs[0] max_length = config_and_inputs[1].shape[-1] + 3 - model = TimesFMForConditionalGeneration(config).eval() + model = TimesFMForPrediction(config).eval() model.to(torch_device) head_masking = { @@ -799,50 +770,6 @@ def prepare_config_and_inputs(self): attention_mask, ) - def create_and_check_model( - self, - config, - input_ids, - attention_mask, - ): - model = TimesFMEncoderModel(config=config) - model.to(torch_device) - model.eval() - result = model( - input_ids=input_ids, - attention_mask=attention_mask, - ) - result = model(input_ids=input_ids) - encoder_output = result.last_hidden_state - - self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) - - def create_and_check_model_fp16_forward( - self, - config, - input_ids, - attention_mask, - ): - model = TimesFMEncoderModel(config=config).to(torch_device).half().eval() - output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def create_and_check_with_token_classification_head( - self, - config, - input_ids, - attention_mask, - ): - labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) - model = TimesFMForTokenClassification(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - labels=labels, - attention_mask=attention_mask, - ) - self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) - self.parent.assertEqual(outputs["loss"].size(), ()) - def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -858,41 +785,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -class TimesFMEncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (TimesFMEncoderModel, TimesFMForTokenClassification) if is_torch_available() else () - test_pruning = False - test_resize_embeddings = False - test_model_parallel = True - pipeline_model_mapping = ( - { - "token-classification": TimesFMForTokenClassification, - } - if is_torch_available() - else {} - ) - all_parallelizable_model_classes = (TimesFMEncoderModel,) if is_torch_available() else () - - def setUp(self): - self.model_tester = TimesFMEncoderOnlyModelTester(self) - self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_model_fp16_forward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) - - def test_with_token_classification_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) - - def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task]) @@ -922,38 +814,38 @@ def import_accelerate_mock(name, *args, **kwargs): with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): accelerate_available = False - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) # Load without in bf16 - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load using `accelerate` in bf16 - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, device_map="auto" ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load using `accelerate` in bf16 - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load without using `accelerate` - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.float16, low_cpu_mem_usage=True ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) # Load using `accelerate` - model = TimesFMForConditionalGeneration.from_pretrained( + model = TimesFMForPrediction.from_pretrained( "google/timesfm-1.0-200m", torch_dtype=torch.float16, device_map="auto" ) self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) @@ -966,7 +858,7 @@ def import_accelerate_mock(name, *args, **kwargs): class TimesFMModelIntegrationTests(unittest.TestCase): @cached_property def model(self): - return TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-base").to(torch_device) + return TimesFMForPrediction.from_pretrained("google-timesfm/timesfm-base").to(torch_device) @cached_property def tokenizer(self): @@ -979,7 +871,7 @@ def test_torch_quant(self): """ model_name = "google/flan-timesfm-small" tokenizer = T5Tokenizer.from_pretrained(model_name) - model = TimesFMForConditionalGeneration.from_pretrained(model_name) + model = TimesFMForPrediction.from_pretrained(model_name) model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" input_ids = tokenizer(input_text, return_tensors="pt").input_ids @@ -987,7 +879,7 @@ def test_torch_quant(self): @slow def test_small_generation(self): - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) model.config.max_length = 8 model.config.num_beams = 1 model.config.do_sample = False @@ -1014,7 +906,7 @@ def test_small_integration_test(self): >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -1040,7 +932,7 @@ def test_small_v1_1_integration_test(self): >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-v1_1-small").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/timesfm-v1_1-small").to(torch_device) tokenizer = T5Tokenizer.from_pretrained("google/timesfm-v1_1-small") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -1064,7 +956,7 @@ def test_small_bytimesfm_integration_test(self): >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TimesFMForConditionalGeneration.from_pretrained("google/bytimesfm-small").to(torch_device) + model = TimesFMForPrediction.from_pretrained("google/bytimesfm-small").to(torch_device) tokenizer = ByT5Tokenizer.from_pretrained("google/bytimesfm-small") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -1405,7 +1297,7 @@ def test_contrastive_search_timesfm(self): ) article = "summarize: " + article.strip() timesfm_tokenizer = AutoTokenizer.from_pretrained("flax-community/timesfm-base-cnn-dm") - timesfm_model = TimesFMForConditionalGeneration.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) + timesfm_model = TimesFMForPrediction.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) input_ids = timesfm_tokenizer( article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" ).input_ids.to(torch_device) @@ -1434,7 +1326,7 @@ def build_model_and_check_forward_pass(self, **kwargs): decoder_attention_mask, lm_labels, ) = inputs - model = TimesFMForConditionalGeneration(config=config).to(torch_device).eval() + model = TimesFMForPrediction(config=config).to(torch_device).eval() outputs = model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, From a81e99b1e2a158af0a7421228edf58d37ba72af4 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Tue, 10 Sep 2024 19:30:18 -0700 Subject: [PATCH 091/242] TimesFM Model --- .../models/timesfm/configuration_timesfm.py | 8 +- .../models/timesfm/modeling_timesfm.py | 1445 ++--------------- 2 files changed, 107 insertions(+), 1346 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index cdee64d7f377..16da290cc0cb 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -47,18 +47,18 @@ class TimesFMConfig(PretrainedConfig): The tolerance for the quantile loss. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. - d_model (`int`, *optional*, defaults to 512): + d_model (`int`, *optional*, defaults to 1280): Size of the encoder layers and the pooler layer. - d_kv (`int`, *optional*, defaults to 64): + d_kv (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * d_kv`. d_ff (`int`, *optional*, defaults to 1280): Size of the intermediate feed forward layer in each `TimesFMBlock`. - num_layers (`int`, *optional*, defaults to 6): + num_layers (`int`, *optional*, defaults to 20): Number of hidden layers in the Transformer encoder. num_decoder_layers (`int`, *optional*): Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. - num_heads (`int`, *optional*, defaults to 8): + num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. relative_attention_num_buckets (`int`, *optional*, defaults to 32): The number of buckets to use for each attention layer. diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 2d35ecdeeca3..852d91320889 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -56,134 +56,6 @@ _CONFIG_FOR_DOC = "TimesFMConfig" _CHECKPOINT_FOR_DOC = "google/timesfm-1.0-200m" -#################################################### -# This dict contains ids and associated url -# for the pretrained weights provided with the models -#################################################### - - -#################################################### -# This is a conversion method from TF 1.0 to PyTorch -# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 -#################################################### -# Copied from transformers.models.t5.modeling_t5.load_tf_weights_in_t5 with t5->timesfm -def load_tf_weights_in_timesfm(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - tf_weights[name] = array - - for txt_name in names: - name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n - in [ - "adam_v", - "adam_m", - "AdamWeightDecayOptimizer", - "AdamWeightDecayOptimizer_1", - "global_step", - ] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - if "_slot_" in name[-1]: - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - pointer = model - array = tf_weights[txt_name] - - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - elif scope_names[0] == "self_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[0] - elif scope_names[0] == "enc_dec_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[1] - elif scope_names[0] == "dense_relu_dense": - pointer = getattr(pointer, "layer") - pointer = pointer[2] - elif scope_names[0] == "rms_norm": - if hasattr(pointer, "layer_norm"): - pointer = getattr(pointer, "layer_norm") - elif hasattr(pointer, "final_layer_norm"): - pointer = getattr(pointer, "final_layer_norm") - elif scope_names[0] == "scale": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - elif scope_names[0] == "decoder" and name[1] == "logits": - continue - elif scope_names[0] == "logits": - pointer = getattr(pointer, "lm_head") - elif ( - scope_names[0] == "wi" - and len(scope_names) > 1 - and scope_names[1].isdigit() - ): - pointer = getattr(pointer, f"wi_{scope_names[1]}") - continue - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if scope_names[0] not in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - if scope_names[0] != "embedding": - logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array.astype(np.float32)) - tf_weights.pop(txt_name, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") - return model - #################################################### # PyTorch Models are constructed by sub-classing @@ -696,642 +568,139 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->TimesFM -class TimesFMLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): - super().__init__() - self.SelfAttention = TimesFMAttention( - config, has_relative_attention_bias=has_relative_attention_bias - ) - self.layer_norm = TimesFMLayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->TimesFM -class TimesFMLayerCrossAttention(nn.Module): - def __init__(self, config): +class TimesFMTransformerLayer(nn.Module): + def __init__(self, config: TimesFMConfig): super().__init__() - self.EncDecAttention = TimesFMAttention( - config, has_relative_attention_bias=False - ) - self.layer_norm = TimesFMLayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.attention = TimesFMAttention(config) + self.ff = TimesFMLayerFF(config) + self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - query_length=query_length, - output_attentions=output_attentions, - ) - layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs + def forward(self, inputs, mask=None): + x = self.layer_norm(inputs) + x = self.attention(x, mask=mask) + x = self.dropout(x) + x = x + inputs + x = self.ff(x) + return x -# Copied from transformers.models.t5.modeling_t5.T5Block with T5->TimesFM -class TimesFMBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): +class TimesFMTransformerStack(nn.Module): + def __init__(self, config: TimesFMConfig): super().__init__() - self.is_decoder = config.is_decoder - self.layer = nn.ModuleList() - self.layer.append( - TimesFMLayerSelfAttention( - config, has_relative_attention_bias=has_relative_attention_bias - ) + self.layers = nn.ModuleList( + [TimesFMTransformerLayer(config) for _ in range(config.num_layers)] ) - if self.is_decoder: - self.layer.append(TimesFMLayerCrossAttention(config)) - self.layer.append(TimesFMLayerFF(config)) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning( - "`past_key_values` is passed to the encoder. Please make sure this is intended." - ) - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + def forward(self, hidden_states, mask=None): + for layer in self.layers: + hidden_states = layer(hidden_states, mask=mask) + return hidden_states - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None +class TimesFMModel(PreTrainedModel): + def __init__(self, config: TimesFMConfig): + super().__init__(config) - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, + self.freq_emb = nn.Embedding( + num_embeddings=config.freq_size, + embedding_dim=config.d_model, + ) + self.position_emb = TimesFMPositionalEmbedding( + embedding_dims=config.d_model, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[ - 2: - ] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - outputs = (hidden_states,) - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs + self.input_ff_layer = TimesFMResidualBlock( + input_dims=config.patch_len * 2, + output_dims=config.d_model, + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + self.stacked_transformer_layer = TimesFMTransformerStack(config) + def preprocess_inputs(self, inputs): + assert len(inputs.shape) == 3 # (batch_size, num_patches, patch_len) + inputs_mean = inputs.mean(dim=(1, 2)) + inputs_std = inputs.std(dim=(1, 2)) + processed_input = (inputs - inputs_mean[:, None, None]) / inputs_std[ + :, None, None + ] + return processed_input, (inputs_mean, inputs_std) -# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->TimesFM -class TimesFMClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" + def create_causal_mask(batch_size, seq_len): + mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() + mask = mask.unsqueeze(0).unsqueeze(1) + mask = mask.expand(batch_size, 1, seq_len, seq_len) + mask = mask.float().masked_fill(mask, -2.3819763e38).masked_fill(~mask, 0.0) + return mask + def forward( + self, + input_ts, + ): + batch_size = input_ts.shape[0] + patched_inputs = input_ts.reshape(batch_size, -1, self.config.patch_len) + patched_pads = torch.zeros_like(patched_inputs) + patched_inputs, input_stats = self.preprocess_inputs(patched_inputs) + concat_inputs = torch.concat([patched_inputs, patched_pads], dim=-1) + + model_input = self.input_ff_layer(concat_inputs) + position_emb = self.position_emb(seq_length=model_input.shape[1]).expand( + model_input.shape[0], -1, -1 + ) + model_input = model_input + position_emb + f_emb = self.freq_emb( + torch.zeros((batch_size, 1), dtype=torch.long) + ) # freq set to zero, change if needed + model_input = model_input + f_emb + mask = self.create_causal_mask(model_input.shape[0], model_input.shape[1]) + model_output = self.stacked_transformer_layer(model_input, mask=mask) + return model_output, input_stats + + +class TimesFMPredictionHead(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() - self.dense = nn.Linear(config.d_model, config.d_model) - self.dropout = nn.Dropout(p=config.classifier_dropout) - self.out_proj = nn.Linear(config.d_model, config.num_labels) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->TimesFM,t5->timesfm -class TimesFMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = TimesFMConfig - load_tf_weights = load_tf_weights_in_timesfm - base_model_prefix = "transformer" - is_parallelizable = True - supports_gradient_checkpointing = True - _no_split_modules = ["TimesFMBlock"] - _keep_in_fp32_modules = ["wo"] - - @property - def dummy_inputs(self): - input_ids = torch.tensor(DUMMY_INPUTS) - input_mask = torch.tensor(DUMMY_MASK) - dummy_inputs = { - "decoder_input_ids": input_ids, - "input_ids": input_ids, - "decoder_attention_mask": input_mask, - } - return dummy_inputs - - def _init_weights(self, module): - """Initialize the weights""" - factor = ( - self.config.initializer_factor - ) # Used for testing weights initialization - if isinstance(module, TimesFMLayerNorm): - module.weight.data.fill_(factor * 1.0) - elif isinstance( - module, - ( - TimesFMModel, - TimesFMForPrediction, - ), - ): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) - if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) - if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - module.qa_outputs.bias.data.zero_() - elif isinstance(module, TimesFMClassificationHead): - module.dense.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() - elif isinstance(module, TimesFMDenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) - if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) - ) - if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() - elif isinstance(module, TimesFMPerHeadDimScale): - module.scale.data.zero_() - elif isinstance(module, TimesFMAttention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 - d_model = self.config.d_model - key_value_proj_dim = self.config.d_kv - n_heads = self.config.num_heads - module.q.weight.data.normal_( - mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) - ) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_( - mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) - ) - if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_( - mean=0.0, std=factor * ((d_model) ** -0.5) - ) - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id + self.config = config + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.d_model, + output_dims=config.horizon_len, + hidden_dims=config.d_ff, + dropout=config.dropout_rate, + ) - if decoder_start_token_id is None: - raise ValueError( - "self.model.config.decoder_start_token_id has to be defined. In TimesFM it is usually set to the pad_token_id. " - "See TimesFM docs for more information." - ) + def postprocess_outputs(self, outputs, stats): + mean, std = stats + return outputs * std[:, None, None, None] + mean[:, None, None, None] - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full( - input_ids.shape[:-1] + (1,), decoder_start_token_id - ) - shifted_input_ids = torch.cat( - [shifted_input_ids, input_ids[..., :-1]], dim=-1 - ) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id + def forward(self, model_output, input_stats): + batch_size = model_output.shape[0] + output_ts = self.horizon_ff_layer(model_output) - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + assert self.config.d_model % self.config.horizon_len == 0 + num_outputs = self.config.d_model // self.config.horizon_len - return shifted_input_ids + output_ts = output_ts.reshape( + batch_size, -1, self.config.horizon_len, num_outputs + ) + output_ts = self.postprocess_outputs(output_ts, input_stats) + return output_ts -class TimesFMStack(TimesFMPreTrainedModel): - def __init__(self, config): +class TimesFMForPrediction(PreTrainedModel): + def __init__(self, config: TimesFMConfig): super().__init__(config) - - self.is_decoder = config.is_decoder - - self.block = nn.ModuleList( - [ - TimesFMBlock(config, has_relative_attention_bias=bool(i == 0)) - for i in range(config.num_layers) - ] - ) - self.dropout = nn.Dropout(config.dropout_rate) - - # Initialize weights and apply final processing - self.post_init() - # Model parallel - self.device_map = None - self.gradient_checkpointing = False + self.timesfm = TimesFMModel(config) + self.prediction_head = TimesFMPredictionHead(config) def forward( self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ts, ): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(self.first_device) - self.embed_tokens = self.embed_tokens.to(self.first_device) - use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" - ) - - if inputs_embeds is None: - if self.embed_tokens is None: - raise ValueError( - "You have to initialize the model with valid token embeddings" - ) - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values[0][0].shape[2] + seq_length - if past_key_values is not None - else seq_length - ) - - if use_cache is True: - if not self.is_decoder: - raise ValueError( - f"`use_cache` can only be set to `True` if {self} is used as a decoder" - ) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape - ) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = ( - encoder_hidden_states.size() - ) - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long - ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask( - cross_attn_head_mask, self.config.num_layers - ) - present_key_value_states = () if use_cache else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds) - - for i, (layer_module, past_key_value) in enumerate( - zip(self.block, past_key_values) - ): - layer_head_mask = head_mask[i] - cross_attn_layer_head_mask = cross_attn_head_mask[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to( - hidden_states.device - ) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = ( - encoder_extended_attention_mask.to(hidden_states.device) - ) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to( - hidden_states.device - ) - if layer_head_mask is not None: - layer_head_mask = layer_head_mask.to(hidden_states.device) - if cross_attn_layer_head_mask is not None: - cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( - hidden_states.device - ) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - extended_attention_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[ - 4 if output_attentions else 3 - ] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + ( - present_key_value_state, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.dropout(hidden_states) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) + model_output, input_stats = self.timesfm(input_ts) + output_ts = self.prediction_head(model_output, input_stats) + return output_ts TIMESFM_START_DOCSTRING = r""" @@ -1447,611 +816,3 @@ def forward( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - -TIMESFM_ENCODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. -""" - - -@add_start_docstrings( - "The bare TIMESFM Model transformer outputting raw hidden-states without any specific head on top.", - TIMESFM_START_DOCSTRING, -) -class TimesFMModel(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", - ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.freq_emb = nn.Embedding(config.freq_size, config.d_model) - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.d_model, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TimesFMStack(decoder_config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" - " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" - " 0, 'encoder.block.1': 1, ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - decoder_inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, TimesFMModel - - >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") - >>> model = TimesFMModel.from_pretrained("google/timesfm-1.0-200m") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for TimesFMModel. - >>> # This is not needed for torch's TimesFMForConditionalGeneration as it does this internally using labels arg. - >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to( - self.decoder.first_device - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings( - """TIMESFM Model with a `language modeling` head on top.""", TIMESFM_START_DOCSTRING -) -class TimesFMForPrediction(TimesFMPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", - ] - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - "lm_head.weight", - ] - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = TimesFMStack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TimesFMStack(decoder_config, self.shared) - - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - warnings.warn( - "`TimesFMForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" - " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" - " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" - " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.decoder.first_device) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def get_output_embeddings(self): - return self.lm_head - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., - config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for - labels in `[0, ..., config.vocab_size]` - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TimesFMForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google/timesfm-1.0-200m") - >>> model = TimesFMForConditionalGeneration.from_pretrained("google/timesfm-1.0-200m") - - >>> # training - >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids - >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids - >>> outputs = model(input_ids=input_ids, labels=labels) - >>> loss = outputs.loss - >>> logits = outputs.logits - - >>> # inference - >>> input_ids = tokenizer( - ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> outputs = model.generate(input_ids) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - >>> # studies have shown that owning a dog is good for you. - ```""" - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - - if ( - labels is not None - and decoder_input_ids is None - and decoder_inputs_embeds is None - ): - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to( - self.decoder.first_device - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.encoder.first_device) - self.lm_head = self.lm_head.to(self.encoder.first_device) - sequence_output = sequence_output.to(self.lm_head.weight.device) - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - lm_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-100) - # move labels to correct device to enable PP - labels = labels.to(lm_logits.device) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning( - "You might want to consider setting `use_cache=True` to speed up decoding" - ) - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device) - ), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, - ) - return reordered_decoder_past From 1ec48c7bb2ff81cb4f9bb7566f3146de550bb67b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 13:55:27 +0200 Subject: [PATCH 092/242] order of imports --- src/transformers/__init__.py | 12 ++++++------ src/transformers/models/__init__.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1e36fe03fbc1..b46578c461bf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6008,7 +6008,6 @@ SwitchTransformersConfig, ) from .models.t5 import T5Config - from .models.timesfm import TimesFMConfig from .models.table_transformer import ( TableTransformerConfig, ) @@ -6020,6 +6019,7 @@ from .models.time_series_transformer import ( TimeSeriesTransformerConfig, ) + from .models.timesfm import TimesFMConfig from .models.timesformer import ( TimesformerConfig, ) @@ -8430,11 +8430,6 @@ T5PreTrainedModel, load_tf_weights_in_t5, ) - from .models.timesfm import ( - TimesFMForPrediction, - TimesFMModel, - TimesFMPreTrainedModel, - ) from .models.table_transformer import ( TableTransformerForObjectDetection, TableTransformerModel, @@ -8459,6 +8454,11 @@ TimeSeriesTransformerModel, TimeSeriesTransformerPreTrainedModel, ) + from .models.timesfm import ( + TimesFMForPrediction, + TimesFMModel, + TimesFMPreTrainedModel, + ) from .models.timesformer import ( TimesformerForVideoClassification, TimesformerModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 12957e1622bd..314a3c4ab68e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -262,11 +262,11 @@ swinv2, switch_transformers, t5, - timesfm, table_transformer, tapas, textnet, time_series_transformer, + timesfm, timesformer, timm_backbone, timm_wrapper, From 8abfc2e60852b279251f0727fa397ad2551a1653 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 18 Sep 2024 17:27:31 -0700 Subject: [PATCH 093/242] copy from Google official implementation --- .../models/timesfm/configuration_timesfm.py | 124 +-- ...mesfm_original_tf_checkpoint_to_pytorch.py | 75 -- .../convert_timesfmx_checkpoint_to_flax.py | 299 ------ .../convert_timesfmx_checkpoint_to_pytorch.py | 279 ------ .../models/timesfm/modeling_timesfm.py | 940 +++--------------- .../models/timesfm/patched_decoder.py | 766 ++++++++++++++ .../models/timesfm/timesfm_base.py | 572 +++++++++++ src/transformers/models/timesfm/xreg_lib.py | 520 ++++++++++ 8 files changed, 2063 insertions(+), 1512 deletions(-) delete mode 100644 src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py delete mode 100644 src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/timesfm/patched_decoder.py create mode 100644 src/transformers/models/timesfm/timesfm_base.py create mode 100644 src/transformers/models/timesfm/xreg_lib.py diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 16da290cc0cb..de82a874771b 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020, The TimesFM Authors and HuggingFace Inc. +# Copyright 2024 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,54 +36,47 @@ class TimesFMConfig(PretrainedConfig): Arguments: patch_len (`int`, *optional*, defaults to 32): - The length of each patch in the sequence. + The length of one patch in the input sequence. horizon_len (`int`, *optional*, defaults to 128): The length of the prediction horizon. - quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): - The quantiles to predict. - pad_val (`float`, *optional*, defaults to 1123581321.0): - The value used to pad the predictions. - tolerance (`float`, *optional*, defaults to 1e-6): - The tolerance for the quantile loss. + context_len (`int`, *optional*, defaults to 512): + The length of the input context. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. - d_model (`int`, *optional*, defaults to 1280): - Size of the encoder layers and the pooler layer. - d_kv (`int`, *optional*, defaults to 80): + model_dim (`int`, *optional*, defaults to 1280): + Size of the hidden layers in the feed-forward networks. + head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will - be defined as `num_heads * d_kv`. - d_ff (`int`, *optional*, defaults to 1280): - Size of the intermediate feed forward layer in each `TimesFMBlock`. + be defined as `num_heads * head_dim`. num_layers (`int`, *optional*, defaults to 20): - Number of hidden layers in the Transformer encoder. - num_decoder_layers (`int`, *optional*): - Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + Number of Transformer layers. num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. - relative_attention_num_buckets (`int`, *optional*, defaults to 32): - The number of buckets to use for each attention layer. - relative_attention_max_distance (`int`, *optional*, defaults to 128): - The maximum distance of the longer sequences for the bucket separation. + tolerance (`float`, *optional*, defaults to 1e-6): + The tolerance for the quantile loss. dropout_rate (`float`, *optional*, defaults to 0.1): The ratio for all dropout layers. classifier_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for classifier. - layer_norm_eps (`float`, *optional*, defaults to 1e-6): - The epsilon used by the layer normalization layers. + rms_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the RMS normalization layers. + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): + The quantiles to predict. + pad_val (`float`, *optional*, defaults to 1123581321.0): + The value used to pad the predictions. + use_positional_embedding (`bool`, *optional*, defaults to `True`): + Whether to add positional embeddings. + per_core_batch_size (`int`, *optional*, defaults to 32): + The batch size per core for data parallelism. initializer_factor (`float`, *optional*, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - feed_forward_proj (`string`, *optional*, defaults to `"relu"`): - Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. TimesFMv1.1 uses the - `"gated-gelu"` feed forward projection. Original TimesFM uses `"relu"`. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). """ model_type = "timesfm" - keys_to_ignore_at_inference = ["past_key_values"] + keys_to_ignore_at_inference = [] attribute_map = { - "hidden_size": "d_model", + "hidden_size": "hidden_size", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } @@ -91,72 +84,41 @@ class TimesFMConfig(PretrainedConfig): def __init__( self, patch_len: int = 32, + context_len: int = 512, horizon_len: int = 128, - quantiles: List[float] = [0.1, 0.25, 0.5, 0.75, 0.9], - pad_val: float = 1123581321.0, + freq_size: int = 3, + num_layers: int = 20, + model_dim: int = 1280, + head_dim: int = 80, + num_heads: int = 16, + dropout_rate: float = 0.1, tolerance: float = 1e-6, - freq_size=3, - d_model=1280, - d_kv=80, - d_ff=1280, - num_layers=20, - num_decoder_layers=None, - num_heads=16, - relative_attention_num_buckets=32, - relative_attention_max_distance=128, - dropout_rate=0.1, - layer_norm_epsilon=1e-6, - initializer_factor=1.0, - feed_forward_proj="relu", - is_encoder_decoder=True, - use_cache=True, - pad_token_id=0, - eos_token_id=1, - classifier_dropout=0.0, + rms_norm_eps: float = 1e-6, + quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + pad_val: float = 1123581321.0, + use_positional_embedding: bool = True, + per_core_batch_size: int = 32, + initializer_factor: float = 1.0, **kwargs, ): self.patch_len = patch_len + self.context_len = context_len self.horizon_len = horizon_len self.quantiles = quantiles self.pad_val = pad_val - self.tolerance = tolerance self.freq_size = freq_size - self.d_model = d_model - self.d_kv = d_kv - self.d_ff = d_ff + self.model_dim = model_dim + self.head_dim = head_dim self.num_layers = num_layers - self.num_decoder_layers = ( - num_decoder_layers if num_decoder_layers is not None else self.num_layers - ) # default = symmetry self.num_heads = num_heads - self.relative_attention_num_buckets = relative_attention_num_buckets - self.relative_attention_max_distance = relative_attention_max_distance self.dropout_rate = dropout_rate - self.classifier_dropout = classifier_dropout - self.layer_norm_epsilon = layer_norm_epsilon + self.tolerance = tolerance + self.rms_norm_eps = rms_norm_eps + self.use_positional_embedding = use_positional_embedding + self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.feed_forward_proj = feed_forward_proj - self.use_cache = use_cache - - act_info = self.feed_forward_proj.split("-") - self.dense_act_fn = act_info[-1] - self.is_gated_act = act_info[0] == "gated" - - if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: - raise ValueError( - f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " - "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " - "'gated-gelu' or 'relu'" - ) - - # for backwards compatibility - if feed_forward_proj == "gated-gelu": - self.dense_act_fn = "gelu_new" super().__init__( - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - is_encoder_decoder=is_encoder_decoder, **kwargs, ) diff --git a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index b1ce727cac0c..000000000000 --- a/src/transformers/models/timesfm/convert_timesfm_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,75 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The TimesFM authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert TimesFM checkpoint.""" - -import argparse - -from transformers import ( - TimesFMConfig, - TimesFMForConditionalGeneration, - load_tf_weights_in_timesfm, -) -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch( - tf_checkpoint_path, config_file, pytorch_dump_path -): - # Initialise PyTorch model - config = TimesFMConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = TimesFMForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_tf_weights_in_timesfm(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the TensorFlow checkpoint path.", - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained TimesFM model. \nThis specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to the output PyTorch model.", - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path - ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py deleted file mode 100644 index f9468ffb84c6..000000000000 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_flax.py +++ /dev/null @@ -1,299 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert TimesFMX checkpoints from the original repository to JAX/FLAX model.""" - -import argparse - -from timesfmx import checkpoints - -from transformers import FlaxTimesFMForConditionalGeneration, TimesFMConfig - - -def convert_timesfmx_checkpoint_to_flax( - timesfmx_checkpoint_path, config_name, flax_dump_folder_path -): - config = TimesFMConfig.from_pretrained(config_name) - flax_model = FlaxTimesFMForConditionalGeneration(config=config) - timesfmx_model = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) - - split_mlp_wi = "wi_0" in timesfmx_model["target"]["encoder"]["layers_0"]["mlp"] - - # Encoder - for layer_index in range(config.num_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["encoder"][layer_name][ - "attention" - ]["value"]["kernel"] - - # Layer Normalization - timesfmx_attention_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ - "pre_attention_layer_norm" - ]["scale"] - - if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ - "wi_0" - ]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ - "wi_1" - ]["kernel"] - else: - timesfmx_mlp_wi = timesfmx_model["target"]["encoder"][layer_name]["mlp"][ - "wi" - ]["kernel"] - - timesfmx_mlp_wo = timesfmx_model["target"]["encoder"][layer_name]["mlp"]["wo"][ - "kernel" - ] - - # Layer Normalization - timesfmx_mlp_layer_norm = timesfmx_model["target"]["encoder"][layer_name][ - "pre_mlp_layer_norm" - ]["scale"] - - # Assigning - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["k"]["kernel"] = timesfmx_attention_key - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["o"]["kernel"] = timesfmx_attention_out - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["q"]["kernel"] = timesfmx_attention_query - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["v"]["kernel"] = timesfmx_attention_value - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"][ - "layer_norm" - ]["weight"] = timesfmx_attention_layer_norm - - if split_mlp_wi: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 - else: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wi"]["kernel"] = timesfmx_mlp_wi - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "DenseReluDense" - ]["wo"]["kernel"] = timesfmx_mlp_wo - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"][ - "layer_norm" - ]["weight"] = timesfmx_mlp_layer_norm - - # Only for layer 0: - timesfmx_encoder_rel_embedding = timesfmx_model["target"]["encoder"]["relpos_bias"][ - "rel_embedding" - ].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ - "relative_attention_bias" - ]["embedding"] = timesfmx_encoder_rel_embedding - - # Assigning - timesfmx_encoder_norm = timesfmx_model["target"]["encoder"]["encoder_norm"]["scale"] - flax_model.params["encoder"]["final_layer_norm"]["weight"] = timesfmx_encoder_norm - - # Decoder - for layer_index in range(config.num_decoder_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - timesfmx_attention_key = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["key"]["kernel"] - timesfmx_attention_out = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["out"]["kernel"] - timesfmx_attention_query = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["query"]["kernel"] - timesfmx_attention_value = timesfmx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["value"]["kernel"] - - # Layer Normalization - timesfmx_pre_attention_layer_norm = timesfmx_model["target"]["decoder"][ - layer_name - ]["pre_self_attention_layer_norm"]["scale"] - - # Encoder-Decoder-Attention - timesfmx_enc_dec_attention_key = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["key"]["kernel"] - timesfmx_enc_dec_attention_out = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["out"]["kernel"] - timesfmx_enc_dec_attention_query = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["query"]["kernel"] - timesfmx_enc_dec_attention_value = timesfmx_model["target"]["decoder"][ - layer_name - ]["encoder_decoder_attention"]["value"]["kernel"] - - # Layer Normalization - timesfmx_cross_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ - "pre_cross_attention_layer_norm" - ]["scale"] - - # MLP - if split_mlp_wi: - timesfmx_mlp_wi_0 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ - "wi_0" - ]["kernel"] - timesfmx_mlp_wi_1 = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ - "wi_1" - ]["kernel"] - else: - timesfmx_mlp_wi = timesfmx_model["target"]["decoder"][layer_name]["mlp"][ - "wi" - ]["kernel"] - - timesfmx_mlp_wo = timesfmx_model["target"]["decoder"][layer_name]["mlp"]["wo"][ - "kernel" - ] - - # Layer Normalization - tx5_mlp_layer_norm = timesfmx_model["target"]["decoder"][layer_name][ - "pre_mlp_layer_norm" - ]["scale"] - - # Assigning - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["k"]["kernel"] = timesfmx_attention_key - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["o"]["kernel"] = timesfmx_attention_out - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["q"]["kernel"] = timesfmx_attention_query - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "SelfAttention" - ]["v"]["kernel"] = timesfmx_attention_value - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"][ - "layer_norm" - ]["weight"] = timesfmx_pre_attention_layer_norm - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["k"]["kernel"] = timesfmx_enc_dec_attention_key - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["o"]["kernel"] = timesfmx_enc_dec_attention_out - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["q"]["kernel"] = timesfmx_enc_dec_attention_query - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "EncDecAttention" - ]["v"]["kernel"] = timesfmx_enc_dec_attention_value - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"][ - "layer_norm" - ]["weight"] = timesfmx_cross_layer_norm - - if split_mlp_wi: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wi_0"]["kernel"] = timesfmx_mlp_wi_0 - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wi_1"]["kernel"] = timesfmx_mlp_wi_1 - else: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wi"]["kernel"] = timesfmx_mlp_wi - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "DenseReluDense" - ]["wo"]["kernel"] = timesfmx_mlp_wo - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"][ - "layer_norm" - ]["weight"] = tx5_mlp_layer_norm - - # Decoder Normalization - tx5_decoder_norm = timesfmx_model["target"]["decoder"]["decoder_norm"]["scale"] - flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm - - # Only for layer 0: - timesfmx_decoder_rel_embedding = timesfmx_model["target"]["decoder"]["relpos_bias"][ - "rel_embedding" - ].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"][ - "relative_attention_bias" - ]["embedding"] = timesfmx_decoder_rel_embedding - - # Token Embeddings - tx5_token_embeddings = timesfmx_model["target"]["token_embedder"]["embedding"] - flax_model.params["shared"]["embedding"] = tx5_token_embeddings - - # LM Head (only in v1.1 checkpoints) - if "logits_dense" in timesfmx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = timesfmx_model["target"]["decoder"][ - "logits_dense" - ]["kernel"] - - flax_model.save_pretrained(flax_dump_folder_path) - print("TimesFMX Model was sucessfully converted!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--timesfmx_checkpoint_path", - default=None, - type=str, - required=True, - help="Path the TX5 checkpoint.", - ) - parser.add_argument( - "--config_name", - default=None, - type=str, - required=True, - help="Config name of TimesFM model.", - ) - parser.add_argument( - "--flax_dump_folder_path", - default=None, - type=str, - required=True, - help="Path to the output FLAX model.", - ) - args = parser.parse_args() - convert_timesfmx_checkpoint_to_flax( - args.timesfmx_checkpoint_path, args.config_name, args.flax_dump_folder_path - ) diff --git a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py deleted file mode 100644 index 8d5f13535e8d..000000000000 --- a/src/transformers/models/timesfm/convert_timesfmx_checkpoint_to_pytorch.py +++ /dev/null @@ -1,279 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Convert TimesFMX checkpoint to PyTorch - -Steps: -- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install -- Get a TimesFMX checkpoint at https://github.com/google-research/timesfmx/blob/main/docs/models.md#timesfm-11-checkpoints Example: - `gsutil -m cp -r gs://timesfm-data/pretrained_models/timesfmx/timesfm_1_1_small $HOME/` -- Create or download a corresponding config for the downloaded model. E.g. for TimesFM v1.1 small, you can use - https://huggingface.co/google/timesfm-v1_1-small/blob/main/config.json -- Convert: - ``` - python3 convert_timesfmx_checkpoint_to_pytorch.py --timesfmx_checkpoint_path=$HOME/timesfm_1_1_small --config_file=config.json\ - --pytorch_dump_path=$HOME/timesfm_1_1_small_pt - ``` -""" - -import argparse -import collections - -import torch -from flax import traverse_util -from timesfmx import checkpoints - -from transformers import ( - TimesFMConfig, - TimesFMEncoderModel, - TimesFMForConditionalGeneration, -) -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def timesfmx_attention_lookup(params, i, prefix, layer_name="attention"): - """Returns the KOQV parameters of (self-)attention. Does not transpose.""" - k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] - o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] - q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] - v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] - return k, o, q, v - - -def timesfmx_mlp_lookup(params, i, prefix, split_mlp_wi=False): - """Returns the MLP parameters of a layer. Does not transpose.""" - if split_mlp_wi: - wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] - wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] - wi = (wi_0, wi_1) - else: - wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] - - wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] - return wi, wo - - -def timesfmx_layer_norm_lookup(params, i, prefix, layer_name): - """Returns the layer norm param of a layer.""" - return params[f"{prefix}/layers_{i}/{layer_name}/scale"] - - -def convert_timesfmx_to_pytorch( - variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool -): - """Converts the parameters from TimesFMX-Flax to Transformers-PyTorch.""" - old = traverse_util.flatten_dict(variables["target"]) - old = {"/".join(k): v for k, v in old.items()} - - # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi - split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old - print("Split MLP:", split_mlp_wi) - - new = collections.OrderedDict() - - # Shared embeddings. - new["shared.weight"] = old["token_embedder/embedding"] - - # Encoder. - for i in range(num_layers): - # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "encoder", "pre_attention_layer_norm" - ) - k, o, q, v = timesfmx_attention_lookup(old, i, "encoder", "attention") - new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm - new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T - new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T - new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T - new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T - - # Block i, layer 1 (MLP). - layer_norm = timesfmx_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm") - wi, wo = timesfmx_mlp_lookup(old, i, "encoder", split_mlp_wi) - new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm - if split_mlp_wi: - new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T - new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T - else: - new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T - new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T - - new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ - "encoder/relpos_bias/rel_embedding" - ].T - new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] - - if not is_encoder_only: - # Decoder. - for i in range(num_decoder_layers): - # Block i, layer 0 (Self Attention). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "decoder", "pre_self_attention_layer_norm" - ) - k, o, q, v = timesfmx_attention_lookup(old, i, "decoder", "self_attention") - new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm - new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T - new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T - new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T - new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T - - # Block i, layer 1 (Cross Attention). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "decoder", "pre_cross_attention_layer_norm" - ) - k, o, q, v = timesfmx_attention_lookup( - old, i, "decoder", "encoder_decoder_attention" - ) - new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm - new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T - new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T - new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T - new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T - - # Block i, layer 2 (MLP). - layer_norm = timesfmx_layer_norm_lookup( - old, i, "decoder", "pre_mlp_layer_norm" - ) - wi, wo = timesfmx_mlp_lookup(old, i, "decoder", split_mlp_wi) - new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm - if split_mlp_wi: - new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T - new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T - else: - new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T - new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T - - new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] - new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = ( - old["decoder/relpos_bias/rel_embedding"].T - ) - - # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) - if "decoder/logits_dense/kernel" in old: - new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T - - return new - - -def make_state_dict(converted_params, is_encoder_only: bool): - """Prepares a state dict for the PyTorch model.""" - # Make a state dict with torch tensors. - state_dict = collections.OrderedDict( - [(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()] - ) - - # Add what is missing. - if "encoder.embed_tokens.weight" not in state_dict: - state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] - - if not is_encoder_only: - if "decoder.embed_tokens.weight" not in state_dict: - state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] - - if "lm_head.weight" not in state_dict: # For old 1.0 models. - print("Using shared word embeddings as lm_head.") - state_dict["lm_head.weight"] = state_dict["shared.weight"] - - return state_dict - - -def load_timesfmx_weights_in_timesfm( - model, config, timesfmx_checkpoint_path, is_encoder_only -): - """Replaces the params in model witht the TimesFMX converted params.""" - variables = checkpoints.load_timesfmx_checkpoint(timesfmx_checkpoint_path) - converted = convert_timesfmx_to_pytorch( - variables, - num_layers=config.num_layers, - num_decoder_layers=config.num_decoder_layers, - is_encoder_only=is_encoder_only, - ) - state_dict = make_state_dict(converted, is_encoder_only) - model.load_state_dict(state_dict, strict=True) - - -def convert_timesfmx_checkpoint_to_pytorch( - timesfmx_checkpoint_path, - config_file, - pytorch_dump_path, - is_encoder_only: bool = False, -): - """Loads the config and model, converts the TimesFMX checkpoint, and saves a PyTorch checkpoint.""" - # Initialise PyTorch model - config = TimesFMConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - # Non-v1.1 checkpoints could also use TimesFMModel, but this works for all. - # The v1.0 checkpoints will simply have an LM head that is the word embeddings. - if is_encoder_only: - model = TimesFMEncoderModel(config) - else: - model = TimesFMForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_timesfmx_weights_in_timesfm( - model, config, timesfmx_checkpoint_path, is_encoder_only - ) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - # Verify that we can load the checkpoint. - model.from_pretrained(pytorch_dump_path) - print("Done") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Converts a native TimesFMX checkpoint into a PyTorch checkpoint." - ) - # Required parameters - parser.add_argument( - "--timesfmx_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the TimesFMX checkpoint.", - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained TimesFM model.\nThis specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to the output PyTorch model.", - ) - parser.add_argument( - "--is_encoder_only", - action="store_true", - help="Check if the model is encoder-decoder model", - default=False, - ) - args = parser.parse_args() - convert_timesfmx_checkpoint_to_pytorch( - args.timesfmx_checkpoint_path, - args.config_file, - args.pytorch_dump_path, - args.is_encoder_only, - ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 852d91320889..ea27c1e75b8c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Mesh TensorFlow authors, TimesFM Authors and HuggingFace Inc. team. +# Copyright 2024 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,805 +14,189 @@ # limitations under the License. """PyTorch TimesFM model.""" -import copy -import math -import os -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from ...activations import ACT2FN -from ...modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, - Seq2SeqModelOutput, -) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ( - ALL_LAYERNORM_LAYERS, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from ...utils import ( - DUMMY_INPUTS, - DUMMY_MASK, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_torch_fx_proxy, - logging, - replace_return_docstrings, -) -from ...utils.model_parallel_utils import assert_device_map, get_device_map -from .configuration_timesfm import TimesFMConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "TimesFMConfig" -_CHECKPOINT_FOR_DOC = "google/timesfm-1.0-200m" - #################################################### # PyTorch Models are constructed by sub-classing # - torch.nn.Module for the layers and # - PreTrainedModel for the models (it-self a sub-class of nn.Module) #################################################### -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, *optional*): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the timesfm models have the - following number of attention modules: - - - google/timesfm-1.0-200m: 6 - - google-timesfm/timesfm-base: 12 - - google-timesfm/timesfm-large: 24 - - google-timesfm/timesfm-3b: 24 - - google-timesfm/timesfm-11b: 24 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs using google-timesfm/timesfm-3b, which has a total of 24 attention modules: - model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. - - Example: - - ```python - # On a 4 GPU machine with google-timesfm/timesfm-3b: - model = TimesFMForConditionalGeneration.from_pretrained("google-timesfm/timesfm-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - -# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->TimesFM -class TimesFMLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Construct a layernorm module in the TimesFM style. No bias and no subtraction of mean. - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - # TimesFM uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) +import logging +from os import path +from typing import Any, Sequence - return self.weight * hidden_states - - -try: - from apex.normalization import FusedRMSNorm - - TimesFMLayerNorm = FusedRMSNorm # noqa - - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of TimesFMLayerNorm" - ) -except ImportError: - # using the normal TimesFMLayerNorm - pass -except Exception: - logger.warning( - "discovered apex but it failed to load, falling back to TimesFMLayerNorm" - ) - pass - -ALL_LAYERNORM_LAYERS.append(TimesFMLayerNorm) - - -class TimesFMResidualBlock(nn.Module): - def __init__(self, input_dims, hidden_dims, output_dims, dropout=0.1): - super().__init__() - - self.hidden_layer = nn.Sequential(nn.Linear(input_dims, hidden_dims), nn.SiLU()) - self.output_layer = nn.Linear(hidden_dims, output_dims) - self.residual_layer = nn.Linear(input_dims, output_dims) - self.dropout = nn.Dropout(dropout) - - def forward(self, inputs): - hidden = self.hidden_layer(inputs) - output = self.output_layer(hidden) - output = self.dropout(output) - residual = self.residual_layer(inputs) - - return output + residual - - -class TimesFMPositionalEmbedding(nn.Module): - """Generates position embedding for a given 1-d sequence. - - Attributes: - min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. - max_timescale: End of the geometric index. Determines the frequency of the - added signal. - embedding_dims: Dimension of the embedding to be generated. - """ +import numpy as np +import torch +from huggingface_hub import snapshot_download +import timesfm_base +import patched_decoder as ppd +from ...modeling_utils import PreTrainedModel - def __init__(self, min_timescale=1, max_timescale=10000, embedding_dims=0): - super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dims = embedding_dims +_TOL = 1e-6 - def forward(self, seq_length=None, position=None): - """Generates a tensor of sinusoids with different frequencies. - Args: - seq_length: an optional Python int defining the output sequence length. - if the `position` argument is specified. - position: [B, seq_length], optional position for each token in the - sequence, only required when the sequence is packed. +class TimesFmTorch(PreTrainedModel, timesfm_base.TimesFmBase): + """TimesFM forecast API for inference.""" - Returns: - [B, seqlen, D] if `position` is specified, else [1, seqlen, D] - """ - if position is None: - if seq_length is None: - raise ValueError("If position is None, seq_length should be specified.") - # [1, seqlen] - position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) - else: - if position.ndim != 2: - raise ValueError( - f"position should have 2 dimensions, got {position.ndim}" - ) - - num_timescales = self.embedding_dims // 2 - log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale) - ) / max(torch.tensor(num_timescales, dtype=torch.float32) - 1, 1) - inv_timescales = self.min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + def __post_init__(self): + self._model_config = ppd.TimesFMConfig( + num_layers=self.num_layers, + num_heads=self.num_heads, + hidden_size=self.model_dims, + intermediate_size=self.model_dims, + patch_len=self.input_patch_len, + horizon_len=self.output_patch_len, + head_dim=self.model_dims // self.num_heads, + quantiles=self.quantiles, ) - scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) - signal = torch.cat( - [torch.sin(scaled_time), torch.cos(scaled_time)], dim=2 - ).type(torch.float32) - - signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) - return signal - - -class TimesFMDenseActDense(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.wi = nn.Linear(config.d_model, config.d_ff, bias=True) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=True) - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ACT2FN[config.dense_act_fn] - - def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states) - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): - hidden_states = hidden_states.to(self.wo.weight.dtype) - hidden_states = self.wo(hidden_states) - return hidden_states - - -class TimesFMLayerFF(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - - self.DenseReluDense = TimesFMDenseActDense(config) - self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class TimesFMPerHeadDimScale(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - dim = config.d_model // config.num_heads - r_softplus_0 = 1.442695041 - self.scale_factor = r_softplus_0 / math.sqrt(dim) - self.scale = nn.Parameter(torch.empty(self.dim)) - - def forward(self, hidden_states): - scale = self.scale_factor * F.softplus(self.scale) - return hidden_states * scale - - -class TimesFMAttention(nn.Module): - def __init__(self, config: TimesFMConfig, has_relative_attention_bias=False): - super().__init__() - self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(self.d_model, self.inner_dim, bias=True) - self.k = nn.Linear(self.d_model, self.inner_dim, bias=True) - self.v = nn.Linear(self.d_model, self.inner_dim, bias=True) - self.o = nn.Linear(self.inner_dim, self.d_model, bias=True) - self.per_head_dim_scale = TimesFMPerHeadDimScale(config) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding( - self.relative_attention_num_buckets, self.n_heads - ) - self.pruned_heads = set() - self.gradient_checkpointing = False - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + self._model = None + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device( + "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" ) - # Prune linear layers - self.q = prune_linear_layer(self.q, index) - self.k = prune_linear_layer(self.k, index) - self.v = prune_linear_layer(self.v, index) - self.o = prune_linear_layer(self.o, index, dim=1) - # Update hyper params - self.n_heads = self.n_heads - len(heads) - self.inner_dim = self.key_value_proj_dim * self.n_heads - self.pruned_heads = self.pruned_heads.union(heads) - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on + def load_from_checkpoint( + self, + checkpoint: timesfm_base.TimesFmCheckpoint, + ) -> None: + """Loads a checkpoint and compiles the decoder.""" + checkpoint_path = checkpoint.path + repo_id = checkpoint.huggingface_repo_id + if checkpoint_path is None: + checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt") + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) + logging.info("Loading checkpoint from %s", checkpoint_path) + self._model.load_state_dict(loaded_checkpoint) + logging.info("Sending checkpoint to device %s", f"{self._device}") + self._model.to(self._device) + self._model.eval() + # TODO: add compilation. + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). - def compute_bias(self, query_length, key_length, device=None): - """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - ): + Raises: + ValueError: If the checkpoint is not properly loaded. """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length + if not self._model: + raise ValueError( + "Checkpoint not loaded. Call `load_from_checkpoint` before" + " `forecast`." ) - - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) - - def shape(states): - """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) - - def unshape(states): - """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - unscaled_query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - query_states = self.per_head_dim_scale(unscaled_query_states) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(timesfm_base.moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) ) - if self.gradient_checkpointing and self.training: - position_bias.requires_grad = True - else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = ( - position_bias + mask - ) # (batch_size, n_heads, seq_length, key_length) - - if self.pruned_heads: - mask = torch.ones(position_bias.shape[1]) - mask[list(self.pruned_heads)] = 0 - position_bias_masked = position_bias[:, mask.bool()] - else: - position_bias_masked = position_bias - - scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (attn_weights,) - return outputs - - -class TimesFMTransformerLayer(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.attention = TimesFMAttention(config) - self.ff = TimesFMLayerFF(config) - self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, inputs, mask=None): - x = self.layer_norm(inputs) - x = self.attention(x, mask=mask) - x = self.dropout(x) - x = x + inputs - x = self.ff(x) - return x - - -class TimesFMTransformerStack(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.layers = nn.ModuleList( - [TimesFMTransformerLayer(config) for _ in range(config.num_layers)] - ) - - def forward(self, hidden_states, mask=None): - for layer in self.layers: - hidden_states = layer(hidden_states, mask=mask) - return hidden_states - - -class TimesFMModel(PreTrainedModel): - def __init__(self, config: TimesFMConfig): - super().__init__(config) - - self.freq_emb = nn.Embedding( - num_embeddings=config.freq_size, - embedding_dim=config.d_model, - ) - self.position_emb = TimesFMPositionalEmbedding( - embedding_dims=config.d_model, - ) - - self.input_ff_layer = TimesFMResidualBlock( - input_dims=config.patch_len * 2, - output_dims=config.d_model, - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - self.stacked_transformer_layer = TimesFMTransformerStack(config) - - def preprocess_inputs(self, inputs): - assert len(inputs.shape) == 3 # (batch_size, num_patches, patch_len) - inputs_mean = inputs.mean(dim=(1, 2)) - inputs_std = inputs.std(dim=(1, 2)) - processed_input = (inputs - inputs_mean[:, None, None]) / inputs_std[ - :, None, None - ] - return processed_input, (inputs_mean, inputs_std) - - def create_causal_mask(batch_size, seq_len): - mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() - mask = mask.unsqueeze(0).unsqueeze(1) - mask = mask.expand(batch_size, 1, seq_len, seq_len) - mask = mask.float().masked_fill(mask, -2.3819763e38).masked_fill(~mask, 0.0) - return mask - - def forward( - self, - input_ts, - ): - batch_size = input_ts.shape[0] - patched_inputs = input_ts.reshape(batch_size, -1, self.config.patch_len) - patched_pads = torch.zeros_like(patched_inputs) - patched_inputs, input_stats = self.preprocess_inputs(patched_inputs) - concat_inputs = torch.concat([patched_inputs, patched_pads], dim=-1) - - model_input = self.input_ff_layer(concat_inputs) - position_emb = self.position_emb(seq_length=model_input.shape[1]).expand( - model_input.shape[0], -1, -1 - ) - model_input = model_input + position_emb - f_emb = self.freq_emb( - torch.zeros((batch_size, 1), dtype=torch.long) - ) # freq set to zero, change if needed - model_input = model_input + f_emb - mask = self.create_causal_mask(model_input.shape[0], model_input.shape[1]) - model_output = self.stacked_transformer_layer(model_input, mask=mask) - return model_output, input_stats - - -class TimesFMPredictionHead(nn.Module): - def __init__(self, config: TimesFMConfig): - super().__init__() - self.config = config - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.d_model, - output_dims=config.horizon_len, - hidden_dims=config.d_ff, - dropout=config.dropout_rate, - ) - - def postprocess_outputs(self, outputs, stats): - mean, std = stats - return outputs * std[:, None, None, None] + mean[:, None, None, None] - - def forward(self, model_output, input_stats): - batch_size = model_output.shape[0] - output_ts = self.horizon_ff_layer(model_output) - - assert self.config.d_model % self.config.horizon_len == 0 - num_outputs = self.config.d_model // self.config.horizon_len - - output_ts = output_ts.reshape( - batch_size, -1, self.config.horizon_len, num_outputs - ) - output_ts = self.postprocess_outputs(output_ts, input_stats) - return output_ts - - -class TimesFMForPrediction(PreTrainedModel): - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self.timesfm = TimesFMModel(config) - self.prediction_head = TimesFMPredictionHead(config) - - def forward( - self, - input_ts, - ): - model_output, input_stats = self.timesfm(input_ts) - output_ts = self.prediction_head(model_output, input_stats) - return output_ts - - -TIMESFM_START_DOCSTRING = r""" - - The TIMESFM model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`TimesFMConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TIMESFM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. TIMESFM is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `input_ids` for pretraining take a look a [TIMESFM Training](./timesfm#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - TIMESFM uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [TIMESFM - Training](./timesfm#training). - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in - `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs diff --git a/src/transformers/models/timesfm/patched_decoder.py b/src/transformers/models/timesfm/patched_decoder.py new file mode 100644 index 000000000000..f7e108bc08d8 --- /dev/null +++ b/src/transformers/models/timesfm/patched_decoder.py @@ -0,0 +1,766 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytorch version of patched decoder.""" + + +import math +from typing import List, Tuple +import torch +from torch import nn +import torch.nn.functional as F +from transformers.models.timesfm.configuration_timesfm import TimesFMConfig + + +def _masked_mean_std( + inputs: torch.Tensor, padding: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded + values. + """ + # Selecting the first patch with more than 3 unpadded values. + pad_sum = torch.sum(1 - padding, dim=2) + + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.where( + num_valid_elements == 0, + torch.tensor( + 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device + ), + num_valid_elements, + ) + + # Calculate the masked sum and squared sum + masked_sum = torch.sum(arr * mask, dim=1) + masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) + + # Calculate the masked mean and standard deviation + masked_mean = masked_sum / num_valid_elements + masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 + masked_var = torch.where( + masked_var < 0.0, + torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), + masked_var, + ) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + +def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + Returns the shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = ( + torch.arange(num_seq) + .to(seq.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, -1, feature_dim) + ) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + +def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Applies a floating-point mask to a set of logits. + + Args: + logits: A torch.Tensor of logit values. + mask: A torch.Tensor (float32) of mask values with the encoding described + in the function documentation. + + Returns: + Masked logits. + """ + + min_value = get_large_negative_number(logits.dtype) + + return torch.where((mask >= min_value * 0.5), logits, min_value) + + +def convert_paddings_to_mask( + paddings: torch.Tensor, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + """Converts binary paddings to a logit mask ready to add to attention matrix. + + Args: + paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding + token. + dtype: data type of the input. + + Returns: + A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. + """ + attention_mask = paddings.detach().clone() + attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis + attention_mask *= get_large_negative_number(dtype) + return attention_mask + + +def causal_mask(input_t: torch.Tensor) -> torch.Tensor: + """Computes and returns causal mask. + + Args: + input_t: A torch.Tensor of shape [B, T, D]. + + Returns: + An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has + already been converted to large negative values. + """ + assert input_t.dtype.is_floating_point, input_t.dtype + large_negative_number = get_large_negative_number(input_t.dtype) + t = input_t.shape[1] + col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) + row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) + mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number + return ( + mask.unsqueeze(0).unsqueeze(0).to(input_t.device) + ) # Equivalent to jnp.newaxis + + +def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Merges 2 masks. + + logscale mask is expected but 0/1 mask is also fine. + + Args: + a: torch.Tensor of shape [1|B, 1, 1|T, S]. + b: torch.Tensor of shape [1|B, 1, 1|T, S]. + + Returns: + torch.Tensor of shape [1|B, 1, 1|T, S]. + """ + + def expand_t(key_mask): + query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose + return torch.minimum(query_mask, key_mask) + + if a.shape[2] != b.shape[2]: + if a.shape[2] == 1: + a = expand_t(a) + else: + assert b.shape[2] == 1 + b = expand_t(b) + + assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." + return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum + + +class ResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super(ResidualBlock, self).__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.hidden_layer(x) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class RMSNorm(torch.nn.Module): + """Pax rms norm in pytorch.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + if self.add_unit_offset: + output = output * (1 + self.weight.float()) + else: + output = output * self.weight.float() + return output.type_as(x) + + +class TransformerMLP(nn.Module): + """Pax transformer MLP in pytorch.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFMAttention(nn.Module): + """Implements the attention used in TimesFM.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = nn.Parameter( + torch.empty((self.head_dim,), dtype=torch.float32), + ) + + self.qkv_proj = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: + # [batch_size, n_local_heads, input_len, head_dim] + r_softplus_0 = 1.442695041 + softplus_func = torch.nn.Softplus() + scale = r_softplus_0 / math.sqrt(self.head_dim) + scale = scale * softplus_func(self.scaling) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xq = self._per_dim_scaling(xq) + + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + else: + key = xk + value = xv + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + # return scores, output.transpose(1, 2).contiguous() + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return scores, output + + +class TimesFMDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.self_attn = TimesFMAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + ) + self.mlp = TransformerMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + scores, hidden_states = self.self_attn( + hidden_states=hidden_states, + mask=mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class StackedDecoder(nn.Module): + """Stacked transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + num_layers: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TimesFMDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + ) + ) + + def forward( + self, + hidden_states: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + ) -> torch.Tensor: + padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) + atten_mask = causal_mask(hidden_states) + mask = merge_masks(padding_mask, atten_mask) + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = kv_caches[i] if kv_caches is not None else None + _, hidden_states = layer( + hidden_states=hidden_states, + mask=mask, + paddings=paddings, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + return hidden_states + + +class PositionalEmbedding(torch.nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__( + self, + embedding_dims: int, + min_timescale: int = 1, + max_timescale: int = 10_000, + ) -> None: + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + assert seq_length is not None + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + assert position.ndim == 2, position.shape + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / max(num_timescales - 1, 1) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +class PatchedTimeSeriesDecoder(nn.Module): + """Patched time-series decoder.""" + + def __init__(self, config: TimesFMConfig): + super().__init__() + self.config = config + self.input_ff_layer = ResidualBlock( + input_dims=2 * config.patch_len, + output_dims=config.model_dim, + hidden_dims=config.model_dim, + ) + self.freq_emb = nn.Embedding(num_embeddings=3, embedding_dim=config.model_dim) + self.horizon_ff_layer = ResidualBlock( + input_dims=config.model_dim, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.model_dim, + ) + self.stacked_transformer = StackedDecoder( + hidden_size=self.config.model_dim, + intermediate_size=self.config.model_dim, + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_heads, + head_dim=self.config.head_dim, + num_layers=self.config.num_layers, + rms_norm_eps=self.config.rms_norm_eps, + ) + if self.config.use_positional_embedding: + self.position_emb = PositionalEmbedding(self.config.model_dim) + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = _masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=outputs.dtype, device=outputs.device + ), + outputs, + ) + return outputs, (mu, sigma) + + def _reverse_transform( + self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] + + def _preprocess_input( + self, + input_ts: torch.Tensor, + input_padding: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor, torch.Tensor] | None, + torch.Tensor, + ]: + """Preprocess input for stacked transformer.""" + + # Reshape into patches (using view for efficiency) + bsize = input_ts.shape[0] + patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) + patched_pads = input_padding.view(bsize, -1, self.config.patch_len) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, dim=-1)[ + 0 + ] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = _shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + return model_input, patched_padding, stats, patched_inputs + + def _postprocess_output( + self, + model_output: torch.Tensor, + num_outputs: int, + stats: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) + + return self._reverse_transform(output_ts, stats) + + def forward( + self, + input_ts: torch.Tensor, + input_padding: torch.LongTensor, + freq: torch.Tensor, + ) -> torch.Tensor: + num_outputs = len(self.config.quantiles) + 1 + model_input, patched_padding, stats, _ = self._preprocess_input( + input_ts=input_ts, + input_padding=input_padding, + ) + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + model_output = self.stacked_transformer(model_input, patched_padding) + + output_ts = self._postprocess_output(model_output, num_outputs, stats) + return output_ts + + def decode( + self, + input_ts: torch.Tensor, + paddings: torch.Tensor, + freq: torch.LongTensor, + horizon_len: int, + output_patch_len: int | None = None, + max_len: int = 512, + return_forecast_on_context: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Auto-regressive decoding without caching. + + Args: + input_ts: input time-series and paddings. Time-series shape B x C. + paddings: padding shape B x (C + H) where H is the prediction length. + freq: frequency shape B x 1 + horizon_len: prediction length. + output_patch_len: output length to be fetched from one step of + auto-regressive decoding. + max_len: maximum training context length. + return_forecast_on_context: whether to return the model forecast on the + context except the first input patch. + + Returns: + Tuple of two forecasting results: + - Point (mean) output predictions as a tensor with shape B x H'. + - Full predictions (mean and quantiles) as a tensor with shape + B x H' x (1 + # quantiles). + In particular, if return_forecast_on_context is True, H' is H plus + the forecastable context length, i.e. context_len - (first) patch_len. + """ + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" + ) + if output_patch_len is None: + output_patch_len = self.config.horizon_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = paddings[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -max_len:] + input_padding = current_padding[:, -max_len:] + fprop_outputs = self(input_ts, input_padding, freq) + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] + new_full_ts = fprop_outputs.view( + new_full_ts.size(0), -1, new_full_ts.size(3) + ) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concatenate([final_out, new_ts], axis=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concatenate(full_outputs, axis=1)[ + :, : (context_len - self.config.patch_len + horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] + + return (full_outputs[:, :, 0], full_outputs) diff --git a/src/transformers/models/timesfm/timesfm_base.py b/src/transformers/models/timesfm/timesfm_base.py new file mode 100644 index 000000000000..c5f113ee6000 --- /dev/null +++ b/src/transformers/models/timesfm/timesfm_base.py @@ -0,0 +1,572 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for TimesFM inference. This will be common to PAX and Pytorch.""" + +import collections +import dataclasses +import logging +import multiprocessing +from typing import Any, Literal, Sequence + +import numpy as np +import pandas as pd + +from utilsforecast.processing import make_future_dataframe +from configuration_timesfm import TimesFMConfig +import xreg_lib + +Category = xreg_lib.Category +XRegMode = xreg_lib.XRegMode + +_TOL = 1e-6 +DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) + + +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if ( + freq.endswith("H") + or freq.endswith("T") + or freq.endswith("MIN") + or freq.endswith("D") + or freq.endswith("B") + or freq.endswith("U") + ): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + +# Per time series normalization: forward. +def normalize(batch): + stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch] + new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] + return new_batch, stats + + +# Per time series normalization: inverse. +def renormalize(batch, stats): + return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] + + +@dataclasses.dataclass(kw_only=True) +class TimesFmCheckpoint: + """Checkpoint used to initialize a TimesFM model for inference. + + Attributes: + version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. + The factory will create the corresponding TimesFm inference class based on + this version. + path: Path to the checkpoint. + type: If provided, type of the checkpoint used by the specific checkpoint + loader per version. + step: If provided, step of the checkpoint. + """ + + version: str = "jax" + path: str | None = None + huggingface_repo_id: str | None = None + type: Any = None + step: int | None = None + + +class TimesFmBase: + """Base TimesFM forecast API for inference. + + This class is the scaffolding for calling TimesFM forecast. To properly use: + 1. Create an instance with the correct hyperparameters of a TimesFM model. + 2. Call `load_from_checkpoint` to load a compatible checkpoint. + 3. Call `forecast` for inference. + """ + + def _logging(self, s): + print(s) + + def __post_init__(self) -> None: + """Additional initialization for subclasses before checkpoint loading.""" + pass + + def __init__(self, hparams: TimesFMConfig, checkpoint: TimesFmCheckpoint) -> None: + """Initializes the TimesFM forecast API. + + Args: + hparams: Hyperparameters of the model. + checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide + which TimesFM version to use. + """ + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.patch_len + self.output_patch_len = hparams.horizon_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dim + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + + # Rewrite these values in __post_init__ for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + + self._horizon_start = self.context_len - self.input_patch_len + self.__post_init__() + self.load_from_checkpoint(checkpoint) + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.array], freq: Sequence[int] + ) -> tuple[np.array, np.array, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d JTensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ( + (len(inputs) - 1) // self.global_batch_size + 1 + ) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate( + [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 + ) + padding = np.concatenate( + [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 + ) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.array, np.array]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + raise NotImplementedError("`forecast` is not implemented.") + + def forecast_with_covariates( + self, + inputs: list[Sequence[float]], + dynamic_numerical_covariates: ( + dict[str, Sequence[Sequence[float]]] | None + ) = None, + dynamic_categorical_covariates: ( + dict[str, Sequence[Sequence[Category]]] | None + ) = None, + static_numerical_covariates: dict[str, Sequence[float]] | None = None, + static_categorical_covariates: dict[str, Sequence[Category]] | None = None, + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + xreg_mode: XRegMode = "xreg + timesfm", + normalize_xreg_target_per_input: bool = True, + ridge: float = 0.0, + max_rows_per_col: int = 0, + force_on_cpu: bool = False, + ): + """Forecasts on a list of time series with covariates. + + To optimize inference speed, avoid string valued categorical covariates. + + Args: + inputs: A list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + dynamic_numerical_covariates: A dict of dynamic numerical covariates. + dynamic_categorical_covariates: A dict of dynamic categorical covariates. + static_numerical_covariates: A dict of static numerical covariates. + static_categorical_covariates: A dict of static categorical covariates. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm" + fits a model on the residuals of the TimesFM forecast. "timesfm + xreg" + fits a model on the targets then forecasts on the residuals via TimesFM. + normalize_xreg_target_per_input: whether to normalize the xreg target per + input in the given batch. + ridge: ridge penalty for the linear model. + max_rows_per_col: max number of rows per column for the linear model. + force_on_cpu: whether to force running on cpu for the linear model. + + Returns: + A tuple of two lists. The first is the outputs of the model. The second is + the outputs of the xreg. + """ + + # Verify and bookkeep covariates. + if not ( + dynamic_numerical_covariates + or dynamic_categorical_covariates + or static_numerical_covariates + or static_categorical_covariates + ): + raise ValueError( + "At least one of dynamic_numerical_covariates," + " dynamic_categorical_covariates, static_numerical_covariates," + " static_categorical_covariates must be set." + ) + + # Track the lengths of (1) each input, (2) the part that can be used in the + # linear model, and (3) the horizon. + input_lens, train_lens, test_lens = [], [], [] + + for i, input_ts in enumerate(inputs): + input_len = len(input_ts) + input_lens.append(input_len) + + if xreg_mode == "timesfm + xreg": + # For fitting residuals, no TimesFM forecast on the first patch. + train_lens.append(max(0, input_len - self.input_patch_len)) + elif xreg_mode == "xreg + timesfm": + train_lens.append(input_len) + else: + raise ValueError(f"Unsupported mode: {xreg_mode}") + + if dynamic_numerical_covariates: + test_lens.append( + len(list(dynamic_numerical_covariates.values())[0][i]) - input_len + ) + elif dynamic_categorical_covariates: + test_lens.append( + len(list(dynamic_categorical_covariates.values())[0][i]) - input_len + ) + else: + test_lens.append(self.horizon_len) + + if test_lens[-1] > self.horizon_len: + raise ValueError( + "Forecast requested longer horizon than the model definition " + f"supports: {test_lens[-1]} vs {self.horizon_len}." + ) + + # Prepare the covariates into train and test. + train_dynamic_numerical_covariates = collections.defaultdict(list) + test_dynamic_numerical_covariates = collections.defaultdict(list) + train_dynamic_categorical_covariates = collections.defaultdict(list) + test_dynamic_categorical_covariates = collections.defaultdict(list) + for covariates, train_covariates, test_covariates in ( + ( + dynamic_numerical_covariates, + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates, + ), + ( + dynamic_categorical_covariates, + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates, + ), + ): + if not covariates: + continue + for covariate_name, covariate_values in covariates.items(): + for input_len, train_len, covariate_value in zip( + input_lens, train_lens, covariate_values + ): + train_covariates[covariate_name].append( + covariate_value[(input_len - train_len) : input_len] + ) + test_covariates[covariate_name].append(covariate_value[input_len:]) + + # Fit models. + if xreg_mode == "timesfm + xreg": + # Forecast via TimesFM then fit a model on the residuals. + mean_outputs, _ = self.forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + targets = [ + ( + np.array(input_ts)[-train_len:] + - mean_output[ + (self._horizon_start - train_len) : self._horizon_start + ] + ) + for input_ts, mean_output, train_len in zip( + inputs, mean_outputs, train_lens + ) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = normalize(targets) + xregs = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=False, + assert_covariates=True, + assert_covariate_shapes=True, + ) + if normalize_xreg_target_per_input: + xregs = renormalize(xregs, per_instance_stats) + outputs = [ + ( + mean_output[self._horizon_start : (self._horizon_start + test_len)] + + xreg + ) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + + else: + # Fit a model on the targets then forecast on the residuals via TimesFM. + targets = [ + np.array(input_ts)[-train_len:] + for input_ts, train_len in zip(inputs, train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = normalize(targets) + xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=True, + assert_covariates=True, + assert_covariate_shapes=True, + ) + mean_outputs, _ = self.forecast( + [ + target - xreg_on_context + for target, xreg_on_context in zip(targets, xregs_on_context) + ], + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + outputs = [ + ( + mean_output[self._horizon_start : (self._horizon_start + test_len)] + + xreg + ) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + if normalize_xreg_target_per_input: + outputs = renormalize(outputs, per_instance_stats) + + return outputs, xregs + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: int | None = None, + num_jobs: int = 1, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. + + Args: + inputs: A pd.DataFrame of all time series. The dataframe should have a + `unique_id` column for identifying the time series, a `ds` column for + timestamps and a value column for the time series values. + freq: string valued `freq` of data. Notice this is different from the + `freq` required by `forecast`. See `freq_map` for allowed values. + forecast_context_len: If provided none zero, we take the last + `forecast_context_len` time-points from each series as the forecast + context instead of the `context_len` set by the model. + value_name: The name of the value column. + model_name: name of the model to be written into future df. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + num_jobs: number of parallel processes to use for dataframe processing. + verbose: output model states in terminal. + + Returns: + Future forecasts dataframe. + """ + if not ( + "unique_id" in inputs.columns + and "ds" in inputs.columns + and value_name in inputs.columns + ): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns." + ) + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + print("Processing dataframe with single process.") + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + print("Processing dataframe with multiple processes.") + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [ + (key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id") + ], + ) + new_inputs, uids = zip(*results) + if verbose: + print("Finished preprocessing dataframe.") + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast( + new_inputs, freq=freq_inps, window_size=window_size + ) + if verbose: + print("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, + ) + fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( + -1, 1 + ) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df diff --git a/src/transformers/models/timesfm/xreg_lib.py b/src/transformers/models/timesfm/xreg_lib.py new file mode 100644 index 000000000000..1c7d253990ca --- /dev/null +++ b/src/transformers/models/timesfm/xreg_lib.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for in-context covariates and regression.""" + +import itertools +import math +from typing import Any, Iterable, Literal, Mapping, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +from sklearn import preprocessing + +Category = int | str + +_TOL = 1e-6 +XRegMode = Literal["timesfm + xreg", "xreg + timesfm"] + + +def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray: + return np.array(list(itertools.chain.from_iterable(nested))) + + +def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray: + return np.array( + list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts))) + ) + + +def _to_padded_jax_array(x: np.ndarray) -> jax.Array: + if x.ndim == 1: + (i,) = x.shape + di = 2 ** math.ceil(math.log2(i)) - i + return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0) + elif x.ndim == 2: + i, j = x.shape + di = 2 ** math.ceil(math.log2(i)) - i + dj = 2 ** math.ceil(math.log2(j)) - j + return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0) + else: + raise ValueError(f"Unsupported array shape: {x.shape}") + + +class BatchedInContextXRegBase: + """Helper class for in-context regression covariate formatting. + + Attributes: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the static + categorical covariates of each forecast task. + """ + + def __init__( + self, + targets: Sequence[Sequence[float]], + train_lens: Sequence[int], + test_lens: Sequence[int], + train_dynamic_numerical_covariates: ( + Mapping[str, Sequence[Sequence[float]]] | None + ) = None, + train_dynamic_categorical_covariates: ( + Mapping[str, Sequence[Sequence[Category]]] | None + ) = None, + test_dynamic_numerical_covariates: ( + Mapping[str, Sequence[Sequence[float]]] | None + ) = None, + test_dynamic_categorical_covariates: ( + Mapping[str, Sequence[Sequence[Category]]] | None + ) = None, + static_numerical_covariates: Mapping[str, Sequence[float]] | None = None, + static_categorical_covariates: Mapping[str, Sequence[Category]] | None = None, + ) -> None: + """Initializes with the exogenous covariate inputs. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. We assume batched inputs. To properly format the + request: + + - `train_lens` represents the contexts in the batch. Targets and all train + dynamic covariates should have the same lengths as the corresponding + elements + in `train_lens`. Notice each `train_len` can be different from the exact + length of the corresponding context depending on how much of the context is + used for fitting the in-context model. + - `test_lens` represents the horizon lengths in the batch. All tesdt + dynamic + covariates should have the same lengths as the corresponding elements in + `test_lens`. + - Static covariates should be one for each input. + - For train and test dynamic covariates, they should have the same + covariate + names. + + Pass an empty dict {} for a covariate type if it is not present. + + Example: + Here is a set of valid inputs whose schema can be used for reference. + ``` + targets = [ + [0.0, 0.1, 0.2], + [0.0, 0.1, 0.2, 0.3], + ] # Two inputs in this batch. + train_lens = [3, 4] + test_lens = [2, 5] # Forecast horizons 2 and 5 respectively. + train_dynamic_numerical_covariates = { + "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]], + "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]], + } # Each train dynamic covariate has 3 and 4 elements respectively. + test_dynamic_numerical_covariates = { + "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]], + "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]], + } # Each test dynamic covariate has 2 and 5 elements respectively. + train_dynamic_categorical_covariates = { + "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]], + "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", + "bad"]], + } + test_dynamic_categorical_covariates = { + "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]], + "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]], + } + static_numerical_covariates = { + "cov_1_sn": [0.0, 3.0], + "cov_2_sn": [2.0, 1.0], + "cov_3_sn": [1.0, 2.0], + } # Each static covariate has 1 element for each input. + static_categorical_covariates = { + "cov_1_sc": ["apple", "orange"], + "cov_2_sc": [2, 3], + } + ``` + + Args: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the context. + Their lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the horizon. + Their lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the + static categorical covariates of each forecast task. + """ + self.targets = targets + self.train_lens = train_lens + self.test_lens = test_lens + self.train_dynamic_numerical_covariates = ( + train_dynamic_numerical_covariates or {} + ) + self.train_dynamic_categorical_covariates = ( + train_dynamic_categorical_covariates or {} + ) + self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {} + self.test_dynamic_categorical_covariates = ( + test_dynamic_categorical_covariates or {} + ) + self.static_numerical_covariates = static_numerical_covariates or {} + self.static_categorical_covariates = static_categorical_covariates or {} + + def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None: + """Verifies the validity of the covariate inputs.""" + + # Check presence. + if ( + self.train_dynamic_numerical_covariates + and not self.test_dynamic_numerical_covariates + ) or ( + not self.train_dynamic_numerical_covariates + and self.test_dynamic_numerical_covariates + ): + raise ValueError( + "train_dynamic_numerical_covariates and" + " test_dynamic_numerical_covariates must be both present or both" + " absent." + ) + + if ( + self.train_dynamic_categorical_covariates + and not self.test_dynamic_categorical_covariates + ) or ( + not self.train_dynamic_categorical_covariates + and self.test_dynamic_categorical_covariates + ): + raise ValueError( + "train_dynamic_categorical_covariates and" + " test_dynamic_categorical_covariates must be both present or both" + " absent." + ) + + # Check keys. + for dict_a, dict_b, dict_a_name, dict_b_name in ( + ( + self.train_dynamic_numerical_covariates, + self.test_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + "test_dynamic_numerical_covariates", + ), + ( + self.train_dynamic_categorical_covariates, + self.test_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + "test_dynamic_categorical_covariates", + ), + ): + if w := set(dict_a.keys()) - set(dict_b.keys()): + raise ValueError( + f"{dict_a_name} has keys not present in {dict_b_name}: {w}" + ) + if w := set(dict_b.keys()) - set(dict_a.keys()): + raise ValueError( + f"{dict_b_name} has keys not present in {dict_a_name}: {w}" + ) + + # Check shapes. + if assert_covariate_shapes: + if len(self.targets) != len(self.train_lens): + raise ValueError( + "targets and train_lens must have the same number of elements." + ) + + if len(self.train_lens) != len(self.test_lens): + raise ValueError( + "train_lens and test_lens must have the same number of elements." + ) + + for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)): + if len(target) != train_len: + raise ValueError( + f"targets[{i}] has length {len(target)} != expected {train_len}." + ) + + for key, values in self.static_numerical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_numerical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}." + ) + + for key, values in self.static_categorical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_categorical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}." + ) + + for lens, dict_cov, dict_cov_name in ( + ( + self.train_lens, + self.train_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + ), + ( + self.train_lens, + self.train_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_numerical_covariates, + "test_dynamic_numerical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_categorical_covariates, + "test_dynamic_categorical_covariates", + ), + ): + for key, cov_values in dict_cov.items(): + if len(cov_values) != len(lens): + raise ValueError( + f"{dict_cov_name} has key {key} with number of examples" + f" {len(cov_values)} != expected {len(lens)}." + ) + for i, cov_value in enumerate(cov_values): + if len(cov_value) != lens[i]: + raise ValueError( + f"{dict_cov_name} has key {key} with its {i}-th example" + f" length {len(cov_value)} != expected {lens[i]}." + ) + + def create_covariate_matrix( + self, + one_hot_encoder_drop: str | None = "first", + use_intercept: bool = True, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Creates target vector and covariate matrices for in context regression. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. + + Args: + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + A tuple of the target vector, the covariate matrix for the context, and + the covariate matrix for the horizon. + """ + if assert_covariates: + self._assert_covariates(assert_covariate_shapes) + + x_train, x_test = [], [] + + # Numerical features. + for name in sorted(self.train_dynamic_numerical_covariates): + x_train.append( + _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis] + ) + x_test.append( + _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis] + ) + + for covs in self.static_numerical_covariates.values(): + x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis]) + x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis]) + + if x_train: + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + # Normalize for robustness. + x_mean = np.mean(x_train, axis=0, keepdims=True) + x_std = np.where( + (w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0 + ) + x_train = [(x_train - x_mean) / x_std] + x_test = [(x_test - x_mean) / x_std] + + # Categorical features. Encode one by one. + one_hot_encoder = preprocessing.OneHotEncoder( + drop=one_hot_encoder_drop, + sparse_output=False, + handle_unknown="ignore", + ) + for name in sorted(self.train_dynamic_categorical_covariates.keys()): + ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[ + :, np.newaxis + ] + ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[ + :, np.newaxis + ] + x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train))) + x_test.append(np.array(one_hot_encoder.transform(ohe_test))) + + for covs in self.static_categorical_covariates.values(): + ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis]) + x_train.append(_repeat(ohe, self.train_lens)) + x_test.append(_repeat(ohe, self.test_lens)) + + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + if use_intercept: + x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0) + x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0) + + return _unnest(self.targets), x_train, x_test + + def fit(self) -> Any: + raise NotImplementedError("Fit is not implemented.") + + +class BatchedInContextXRegLinear(BatchedInContextXRegBase): + """Linear in-context regression model.""" + + def fit( + self, + ridge: float = 0.0, + one_hot_encoder_drop: str | None = "first", + use_intercept: bool = True, + force_on_cpu: bool = False, + max_rows_per_col: int = 0, + max_rows_per_col_sample_seed: int = 42, + debug_info: bool = False, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> ( + list[np.ndarray] + | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array] + ): + """Fits a linear model for in-context regression. + + Args: + ridge: A non-negative value for specifying the ridge regression penalty. + If 0 is provided, fallback to ordinary least squares. Note this penalty + is added to the normalized covariate matrix. + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + force_on_cpu: Whether to force execution on cpu for accelerator machines. + max_rows_per_col: How many rows to subsample per column. 0 for no + subsampling. This is for speeding up model fitting. + max_rows_per_col_sample_seed: The seed for the subsampling if needed by + `max_rows_per_col`. + debug_info: Whether to return debug info. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + If `debug_info` is False: + The linear fits on the horizon. + If `debug_info` is True: + A tuple of: + - the linear fits on the horizon, + - the linear fits on the context, + - the flattened target vector, + - the covariate matrix for the context, and + - the covariate matrix for the horizon. + """ + flat_targets, x_train_raw, x_test = self.create_covariate_matrix( + one_hot_encoder_drop=one_hot_encoder_drop, + use_intercept=use_intercept, + assert_covariates=assert_covariates, + assert_covariate_shapes=assert_covariate_shapes, + ) + + x_train = x_train_raw.copy() + if max_rows_per_col: + nrows, ncols = x_train.shape + if nrows > (w := ncols * max_rows_per_col): + subsample = jax.random.choice( + jax.random.PRNGKey(max_rows_per_col_sample_seed), + nrows, + (w,), + replace=False, + ) + x_train = x_train[subsample] + flat_targets = flat_targets[subsample] + + device = jax.devices("cpu")[0] if force_on_cpu else None + # Runs jitted version of the solvers which are quicker at the cost of + # running jitting during the first time calling. Re-jitting happens whenever + # new (padded) shapes are encountered. + # Ocassionally it helps with the speed and the accuracy if we force single + # thread execution on cpu for accelerator machines: + # 1. Avoid moving data to accelarator memory. + # 2. Avoid precision loss if any. + with jax.default_device(device): + x_train_raw = _to_padded_jax_array(x_train_raw) + x_train = _to_padded_jax_array(x_train) + flat_targets = _to_padded_jax_array(flat_targets) + x_test = _to_padded_jax_array(x_test) + beta_hat = ( + jnp.linalg.pinv( + x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]), + hermitian=True, + ) + @ x_train.T + @ flat_targets + ) + y_hat = x_test @ beta_hat + y_hat_context = x_train_raw @ beta_hat if debug_info else None + + outputs = [] + outputs_context = [] + + # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits. + train_index, test_index = 0, 0 + for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens): + outputs.append( + np.array(y_hat[test_index : (test_index + test_index_delta)]) + ) + if debug_info: + outputs_context.append( + np.array( + y_hat_context[train_index : (train_index + train_index_delta)] + ) + ) + train_index += train_index_delta + test_index += test_index_delta + + if debug_info: + return outputs, outputs_context, flat_targets, x_train, x_test + else: + return outputs From 7e0305aea487cd3e02c99712567639dd26bd7019 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Sun, 22 Sep 2024 19:07:03 -0700 Subject: [PATCH 094/242] remove covariate forecasting --- .../{modeling_timesfm.py => timesfm.py} | 0 .../models/timesfm/timesfm_base.py | 232 -------- src/transformers/models/timesfm/xreg_lib.py | 520 ------------------ 3 files changed, 752 deletions(-) rename src/transformers/models/timesfm/{modeling_timesfm.py => timesfm.py} (100%) delete mode 100644 src/transformers/models/timesfm/xreg_lib.py diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/timesfm.py similarity index 100% rename from src/transformers/models/timesfm/modeling_timesfm.py rename to src/transformers/models/timesfm/timesfm.py diff --git a/src/transformers/models/timesfm/timesfm_base.py b/src/transformers/models/timesfm/timesfm_base.py index c5f113ee6000..7c0c756e6847 100644 --- a/src/transformers/models/timesfm/timesfm_base.py +++ b/src/transformers/models/timesfm/timesfm_base.py @@ -25,10 +25,6 @@ from utilsforecast.processing import make_future_dataframe from configuration_timesfm import TimesFMConfig -import xreg_lib - -Category = xreg_lib.Category -XRegMode = xreg_lib.XRegMode _TOL = 1e-6 DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) @@ -245,234 +241,6 @@ def forecast( """ raise NotImplementedError("`forecast` is not implemented.") - def forecast_with_covariates( - self, - inputs: list[Sequence[float]], - dynamic_numerical_covariates: ( - dict[str, Sequence[Sequence[float]]] | None - ) = None, - dynamic_categorical_covariates: ( - dict[str, Sequence[Sequence[Category]]] | None - ) = None, - static_numerical_covariates: dict[str, Sequence[float]] | None = None, - static_categorical_covariates: dict[str, Sequence[Category]] | None = None, - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - xreg_mode: XRegMode = "xreg + timesfm", - normalize_xreg_target_per_input: bool = True, - ridge: float = 0.0, - max_rows_per_col: int = 0, - force_on_cpu: bool = False, - ): - """Forecasts on a list of time series with covariates. - - To optimize inference speed, avoid string valued categorical covariates. - - Args: - inputs: A list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - dynamic_numerical_covariates: A dict of dynamic numerical covariates. - dynamic_categorical_covariates: A dict of dynamic categorical covariates. - static_numerical_covariates: A dict of static numerical covariates. - static_categorical_covariates: A dict of static categorical covariates. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm" - fits a model on the residuals of the TimesFM forecast. "timesfm + xreg" - fits a model on the targets then forecasts on the residuals via TimesFM. - normalize_xreg_target_per_input: whether to normalize the xreg target per - input in the given batch. - ridge: ridge penalty for the linear model. - max_rows_per_col: max number of rows per column for the linear model. - force_on_cpu: whether to force running on cpu for the linear model. - - Returns: - A tuple of two lists. The first is the outputs of the model. The second is - the outputs of the xreg. - """ - - # Verify and bookkeep covariates. - if not ( - dynamic_numerical_covariates - or dynamic_categorical_covariates - or static_numerical_covariates - or static_categorical_covariates - ): - raise ValueError( - "At least one of dynamic_numerical_covariates," - " dynamic_categorical_covariates, static_numerical_covariates," - " static_categorical_covariates must be set." - ) - - # Track the lengths of (1) each input, (2) the part that can be used in the - # linear model, and (3) the horizon. - input_lens, train_lens, test_lens = [], [], [] - - for i, input_ts in enumerate(inputs): - input_len = len(input_ts) - input_lens.append(input_len) - - if xreg_mode == "timesfm + xreg": - # For fitting residuals, no TimesFM forecast on the first patch. - train_lens.append(max(0, input_len - self.input_patch_len)) - elif xreg_mode == "xreg + timesfm": - train_lens.append(input_len) - else: - raise ValueError(f"Unsupported mode: {xreg_mode}") - - if dynamic_numerical_covariates: - test_lens.append( - len(list(dynamic_numerical_covariates.values())[0][i]) - input_len - ) - elif dynamic_categorical_covariates: - test_lens.append( - len(list(dynamic_categorical_covariates.values())[0][i]) - input_len - ) - else: - test_lens.append(self.horizon_len) - - if test_lens[-1] > self.horizon_len: - raise ValueError( - "Forecast requested longer horizon than the model definition " - f"supports: {test_lens[-1]} vs {self.horizon_len}." - ) - - # Prepare the covariates into train and test. - train_dynamic_numerical_covariates = collections.defaultdict(list) - test_dynamic_numerical_covariates = collections.defaultdict(list) - train_dynamic_categorical_covariates = collections.defaultdict(list) - test_dynamic_categorical_covariates = collections.defaultdict(list) - for covariates, train_covariates, test_covariates in ( - ( - dynamic_numerical_covariates, - train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates, - ), - ( - dynamic_categorical_covariates, - train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates, - ), - ): - if not covariates: - continue - for covariate_name, covariate_values in covariates.items(): - for input_len, train_len, covariate_value in zip( - input_lens, train_lens, covariate_values - ): - train_covariates[covariate_name].append( - covariate_value[(input_len - train_len) : input_len] - ) - test_covariates[covariate_name].append(covariate_value[input_len:]) - - # Fit models. - if xreg_mode == "timesfm + xreg": - # Forecast via TimesFM then fit a model on the residuals. - mean_outputs, _ = self.forecast( - inputs, - freq, - window_size, - forecast_context_len, - return_forecast_on_context=True, - ) - targets = [ - ( - np.array(input_ts)[-train_len:] - - mean_output[ - (self._horizon_start - train_len) : self._horizon_start - ] - ) - for input_ts, mean_output, train_len in zip( - inputs, mean_outputs, train_lens - ) - ] - per_instance_stats = None - if normalize_xreg_target_per_input: - targets, per_instance_stats = normalize(targets) - xregs = xreg_lib.BatchedInContextXRegLinear( - targets=targets, - train_lens=train_lens, - test_lens=test_lens, - train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, - train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, - static_numerical_covariates=static_numerical_covariates, - static_categorical_covariates=static_categorical_covariates, - ).fit( - ridge=ridge, - one_hot_encoder_drop=None if ridge > 0 else "first", - max_rows_per_col=max_rows_per_col, - force_on_cpu=force_on_cpu, - debug_info=False, - assert_covariates=True, - assert_covariate_shapes=True, - ) - if normalize_xreg_target_per_input: - xregs = renormalize(xregs, per_instance_stats) - outputs = [ - ( - mean_output[self._horizon_start : (self._horizon_start + test_len)] - + xreg - ) - for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) - ] - - else: - # Fit a model on the targets then forecast on the residuals via TimesFM. - targets = [ - np.array(input_ts)[-train_len:] - for input_ts, train_len in zip(inputs, train_lens) - ] - per_instance_stats = None - if normalize_xreg_target_per_input: - targets, per_instance_stats = normalize(targets) - xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( - targets=targets, - train_lens=train_lens, - test_lens=test_lens, - train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, - test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, - train_dynamic_categorical_covariates=train_dynamic_categorical_covariates, - test_dynamic_categorical_covariates=test_dynamic_categorical_covariates, - static_numerical_covariates=static_numerical_covariates, - static_categorical_covariates=static_categorical_covariates, - ).fit( - ridge=ridge, - one_hot_encoder_drop=None if ridge > 0 else "first", - max_rows_per_col=max_rows_per_col, - force_on_cpu=force_on_cpu, - debug_info=True, - assert_covariates=True, - assert_covariate_shapes=True, - ) - mean_outputs, _ = self.forecast( - [ - target - xreg_on_context - for target, xreg_on_context in zip(targets, xregs_on_context) - ], - freq, - window_size, - forecast_context_len, - return_forecast_on_context=True, - ) - outputs = [ - ( - mean_output[self._horizon_start : (self._horizon_start + test_len)] - + xreg - ) - for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) - ] - if normalize_xreg_target_per_input: - outputs = renormalize(outputs, per_instance_stats) - - return outputs, xregs - def forecast_on_df( self, inputs: pd.DataFrame, diff --git a/src/transformers/models/timesfm/xreg_lib.py b/src/transformers/models/timesfm/xreg_lib.py deleted file mode 100644 index 1c7d253990ca..000000000000 --- a/src/transformers/models/timesfm/xreg_lib.py +++ /dev/null @@ -1,520 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Helper functions for in-context covariates and regression.""" - -import itertools -import math -from typing import Any, Iterable, Literal, Mapping, Sequence - -import jax -import jax.numpy as jnp -import numpy as np -from sklearn import preprocessing - -Category = int | str - -_TOL = 1e-6 -XRegMode = Literal["timesfm + xreg", "xreg + timesfm"] - - -def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray: - return np.array(list(itertools.chain.from_iterable(nested))) - - -def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray: - return np.array( - list(itertools.chain.from_iterable(map(itertools.repeat, elements, counts))) - ) - - -def _to_padded_jax_array(x: np.ndarray) -> jax.Array: - if x.ndim == 1: - (i,) = x.shape - di = 2 ** math.ceil(math.log2(i)) - i - return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0) - elif x.ndim == 2: - i, j = x.shape - di = 2 ** math.ceil(math.log2(i)) - i - dj = 2 ** math.ceil(math.log2(j)) - j - return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0) - else: - raise ValueError(f"Unsupported array shape: {x.shape}") - - -class BatchedInContextXRegBase: - """Helper class for in-context regression covariate formatting. - - Attributes: - targets: List of targets (responses) of the in-context regression. - train_lens: List of lengths of each target vector from the context. - test_lens: List of lengths of each forecast horizon. - train_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the context. Their - lengths should match the corresponding lengths in `train_lens`. - train_dynamic_categorical_covariates: Dict of covariate names mapping to the - dynamic categorical covariates of each forecast task on the context. Their - lengths should match the corresponding lengths in `train_lens`. - test_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the horizon. Their - lengths should match the corresponding lengths in `test_lens`. - test_dynamic_categorical_covariates: Dict of covariate names mapping to the - dynamic categorical covariates of each forecast task on the horizon. Their - lengths should match the corresponding lengths in `test_lens`. - static_numerical_covariates: Dict of covariate names mapping to the static - numerical covariates of each forecast task. - static_categorical_covariates: Dict of covariate names mapping to the static - categorical covariates of each forecast task. - """ - - def __init__( - self, - targets: Sequence[Sequence[float]], - train_lens: Sequence[int], - test_lens: Sequence[int], - train_dynamic_numerical_covariates: ( - Mapping[str, Sequence[Sequence[float]]] | None - ) = None, - train_dynamic_categorical_covariates: ( - Mapping[str, Sequence[Sequence[Category]]] | None - ) = None, - test_dynamic_numerical_covariates: ( - Mapping[str, Sequence[Sequence[float]]] | None - ) = None, - test_dynamic_categorical_covariates: ( - Mapping[str, Sequence[Sequence[Category]]] | None - ) = None, - static_numerical_covariates: Mapping[str, Sequence[float]] | None = None, - static_categorical_covariates: Mapping[str, Sequence[Category]] | None = None, - ) -> None: - """Initializes with the exogenous covariate inputs. - - Here we use model fitting language to refer to the context as 'train' and - the horizon as 'test'. We assume batched inputs. To properly format the - request: - - - `train_lens` represents the contexts in the batch. Targets and all train - dynamic covariates should have the same lengths as the corresponding - elements - in `train_lens`. Notice each `train_len` can be different from the exact - length of the corresponding context depending on how much of the context is - used for fitting the in-context model. - - `test_lens` represents the horizon lengths in the batch. All tesdt - dynamic - covariates should have the same lengths as the corresponding elements in - `test_lens`. - - Static covariates should be one for each input. - - For train and test dynamic covariates, they should have the same - covariate - names. - - Pass an empty dict {} for a covariate type if it is not present. - - Example: - Here is a set of valid inputs whose schema can be used for reference. - ``` - targets = [ - [0.0, 0.1, 0.2], - [0.0, 0.1, 0.2, 0.3], - ] # Two inputs in this batch. - train_lens = [3, 4] - test_lens = [2, 5] # Forecast horizons 2 and 5 respectively. - train_dynamic_numerical_covariates = { - "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]], - "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]], - } # Each train dynamic covariate has 3 and 4 elements respectively. - test_dynamic_numerical_covariates = { - "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]], - "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]], - } # Each test dynamic covariate has 2 and 5 elements respectively. - train_dynamic_categorical_covariates = { - "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]], - "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", - "bad"]], - } - test_dynamic_categorical_covariates = { - "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]], - "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]], - } - static_numerical_covariates = { - "cov_1_sn": [0.0, 3.0], - "cov_2_sn": [2.0, 1.0], - "cov_3_sn": [1.0, 2.0], - } # Each static covariate has 1 element for each input. - static_categorical_covariates = { - "cov_1_sc": ["apple", "orange"], - "cov_2_sc": [2, 3], - } - ``` - - Args: - targets: List of targets (responses) of the in-context regression. - train_lens: List of lengths of each target vector from the context. - test_lens: List of lengths of each forecast horizon. - train_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the context. Their - lengths should match the corresponding lengths in `train_lens`. - train_dynamic_categorical_covariates: Dict of covariate names mapping to - the dynamic categorical covariates of each forecast task on the context. - Their lengths should match the corresponding lengths in `train_lens`. - test_dynamic_numerical_covariates: Dict of covariate names mapping to the - dynamic numerical covariates of each forecast task on the horizon. Their - lengths should match the corresponding lengths in `test_lens`. - test_dynamic_categorical_covariates: Dict of covariate names mapping to - the dynamic categorical covariates of each forecast task on the horizon. - Their lengths should match the corresponding lengths in `test_lens`. - static_numerical_covariates: Dict of covariate names mapping to the static - numerical covariates of each forecast task. - static_categorical_covariates: Dict of covariate names mapping to the - static categorical covariates of each forecast task. - """ - self.targets = targets - self.train_lens = train_lens - self.test_lens = test_lens - self.train_dynamic_numerical_covariates = ( - train_dynamic_numerical_covariates or {} - ) - self.train_dynamic_categorical_covariates = ( - train_dynamic_categorical_covariates or {} - ) - self.test_dynamic_numerical_covariates = test_dynamic_numerical_covariates or {} - self.test_dynamic_categorical_covariates = ( - test_dynamic_categorical_covariates or {} - ) - self.static_numerical_covariates = static_numerical_covariates or {} - self.static_categorical_covariates = static_categorical_covariates or {} - - def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None: - """Verifies the validity of the covariate inputs.""" - - # Check presence. - if ( - self.train_dynamic_numerical_covariates - and not self.test_dynamic_numerical_covariates - ) or ( - not self.train_dynamic_numerical_covariates - and self.test_dynamic_numerical_covariates - ): - raise ValueError( - "train_dynamic_numerical_covariates and" - " test_dynamic_numerical_covariates must be both present or both" - " absent." - ) - - if ( - self.train_dynamic_categorical_covariates - and not self.test_dynamic_categorical_covariates - ) or ( - not self.train_dynamic_categorical_covariates - and self.test_dynamic_categorical_covariates - ): - raise ValueError( - "train_dynamic_categorical_covariates and" - " test_dynamic_categorical_covariates must be both present or both" - " absent." - ) - - # Check keys. - for dict_a, dict_b, dict_a_name, dict_b_name in ( - ( - self.train_dynamic_numerical_covariates, - self.test_dynamic_numerical_covariates, - "train_dynamic_numerical_covariates", - "test_dynamic_numerical_covariates", - ), - ( - self.train_dynamic_categorical_covariates, - self.test_dynamic_categorical_covariates, - "train_dynamic_categorical_covariates", - "test_dynamic_categorical_covariates", - ), - ): - if w := set(dict_a.keys()) - set(dict_b.keys()): - raise ValueError( - f"{dict_a_name} has keys not present in {dict_b_name}: {w}" - ) - if w := set(dict_b.keys()) - set(dict_a.keys()): - raise ValueError( - f"{dict_b_name} has keys not present in {dict_a_name}: {w}" - ) - - # Check shapes. - if assert_covariate_shapes: - if len(self.targets) != len(self.train_lens): - raise ValueError( - "targets and train_lens must have the same number of elements." - ) - - if len(self.train_lens) != len(self.test_lens): - raise ValueError( - "train_lens and test_lens must have the same number of elements." - ) - - for i, (target, train_len) in enumerate(zip(self.targets, self.train_lens)): - if len(target) != train_len: - raise ValueError( - f"targets[{i}] has length {len(target)} != expected {train_len}." - ) - - for key, values in self.static_numerical_covariates.items(): - if len(values) != len(self.train_lens): - raise ValueError( - f"static_numerical_covariates has key {key} with number of" - f" examples {len(values)} != expected {len(self.train_lens)}." - ) - - for key, values in self.static_categorical_covariates.items(): - if len(values) != len(self.train_lens): - raise ValueError( - f"static_categorical_covariates has key {key} with number of" - f" examples {len(values)} != expected {len(self.train_lens)}." - ) - - for lens, dict_cov, dict_cov_name in ( - ( - self.train_lens, - self.train_dynamic_numerical_covariates, - "train_dynamic_numerical_covariates", - ), - ( - self.train_lens, - self.train_dynamic_categorical_covariates, - "train_dynamic_categorical_covariates", - ), - ( - self.test_lens, - self.test_dynamic_numerical_covariates, - "test_dynamic_numerical_covariates", - ), - ( - self.test_lens, - self.test_dynamic_categorical_covariates, - "test_dynamic_categorical_covariates", - ), - ): - for key, cov_values in dict_cov.items(): - if len(cov_values) != len(lens): - raise ValueError( - f"{dict_cov_name} has key {key} with number of examples" - f" {len(cov_values)} != expected {len(lens)}." - ) - for i, cov_value in enumerate(cov_values): - if len(cov_value) != lens[i]: - raise ValueError( - f"{dict_cov_name} has key {key} with its {i}-th example" - f" length {len(cov_value)} != expected {lens[i]}." - ) - - def create_covariate_matrix( - self, - one_hot_encoder_drop: str | None = "first", - use_intercept: bool = True, - assert_covariates: bool = False, - assert_covariate_shapes: bool = False, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Creates target vector and covariate matrices for in context regression. - - Here we use model fitting language to refer to the context as 'train' and - the horizon as 'test'. - - Args: - one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. - use_intercept: Whether to prepare an intercept (all 1) column in the - matrices. - assert_covariates: Whether to assert the validity of the covariate inputs. - assert_covariate_shapes: Whether to assert the shapes of the covariate - inputs when `assert_covariates` is True. - - Returns: - A tuple of the target vector, the covariate matrix for the context, and - the covariate matrix for the horizon. - """ - if assert_covariates: - self._assert_covariates(assert_covariate_shapes) - - x_train, x_test = [], [] - - # Numerical features. - for name in sorted(self.train_dynamic_numerical_covariates): - x_train.append( - _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis] - ) - x_test.append( - _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis] - ) - - for covs in self.static_numerical_covariates.values(): - x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis]) - x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis]) - - if x_train: - x_train = np.concatenate(x_train, axis=1) - x_test = np.concatenate(x_test, axis=1) - - # Normalize for robustness. - x_mean = np.mean(x_train, axis=0, keepdims=True) - x_std = np.where( - (w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, 1.0 - ) - x_train = [(x_train - x_mean) / x_std] - x_test = [(x_test - x_mean) / x_std] - - # Categorical features. Encode one by one. - one_hot_encoder = preprocessing.OneHotEncoder( - drop=one_hot_encoder_drop, - sparse_output=False, - handle_unknown="ignore", - ) - for name in sorted(self.train_dynamic_categorical_covariates.keys()): - ohe_train = _unnest(self.train_dynamic_categorical_covariates[name])[ - :, np.newaxis - ] - ohe_test = _unnest(self.test_dynamic_categorical_covariates[name])[ - :, np.newaxis - ] - x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train))) - x_test.append(np.array(one_hot_encoder.transform(ohe_test))) - - for covs in self.static_categorical_covariates.values(): - ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis]) - x_train.append(_repeat(ohe, self.train_lens)) - x_test.append(_repeat(ohe, self.test_lens)) - - x_train = np.concatenate(x_train, axis=1) - x_test = np.concatenate(x_test, axis=1) - - if use_intercept: - x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0) - x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0) - - return _unnest(self.targets), x_train, x_test - - def fit(self) -> Any: - raise NotImplementedError("Fit is not implemented.") - - -class BatchedInContextXRegLinear(BatchedInContextXRegBase): - """Linear in-context regression model.""" - - def fit( - self, - ridge: float = 0.0, - one_hot_encoder_drop: str | None = "first", - use_intercept: bool = True, - force_on_cpu: bool = False, - max_rows_per_col: int = 0, - max_rows_per_col_sample_seed: int = 42, - debug_info: bool = False, - assert_covariates: bool = False, - assert_covariate_shapes: bool = False, - ) -> ( - list[np.ndarray] - | tuple[list[np.ndarray], list[np.ndarray], jax.Array, jax.Array, jax.Array] - ): - """Fits a linear model for in-context regression. - - Args: - ridge: A non-negative value for specifying the ridge regression penalty. - If 0 is provided, fallback to ordinary least squares. Note this penalty - is added to the normalized covariate matrix. - one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. - use_intercept: Whether to prepare an intercept (all 1) column in the - matrices. - force_on_cpu: Whether to force execution on cpu for accelerator machines. - max_rows_per_col: How many rows to subsample per column. 0 for no - subsampling. This is for speeding up model fitting. - max_rows_per_col_sample_seed: The seed for the subsampling if needed by - `max_rows_per_col`. - debug_info: Whether to return debug info. - assert_covariates: Whether to assert the validity of the covariate inputs. - assert_covariate_shapes: Whether to assert the shapes of the covariate - inputs when `assert_covariates` is True. - - Returns: - If `debug_info` is False: - The linear fits on the horizon. - If `debug_info` is True: - A tuple of: - - the linear fits on the horizon, - - the linear fits on the context, - - the flattened target vector, - - the covariate matrix for the context, and - - the covariate matrix for the horizon. - """ - flat_targets, x_train_raw, x_test = self.create_covariate_matrix( - one_hot_encoder_drop=one_hot_encoder_drop, - use_intercept=use_intercept, - assert_covariates=assert_covariates, - assert_covariate_shapes=assert_covariate_shapes, - ) - - x_train = x_train_raw.copy() - if max_rows_per_col: - nrows, ncols = x_train.shape - if nrows > (w := ncols * max_rows_per_col): - subsample = jax.random.choice( - jax.random.PRNGKey(max_rows_per_col_sample_seed), - nrows, - (w,), - replace=False, - ) - x_train = x_train[subsample] - flat_targets = flat_targets[subsample] - - device = jax.devices("cpu")[0] if force_on_cpu else None - # Runs jitted version of the solvers which are quicker at the cost of - # running jitting during the first time calling. Re-jitting happens whenever - # new (padded) shapes are encountered. - # Ocassionally it helps with the speed and the accuracy if we force single - # thread execution on cpu for accelerator machines: - # 1. Avoid moving data to accelarator memory. - # 2. Avoid precision loss if any. - with jax.default_device(device): - x_train_raw = _to_padded_jax_array(x_train_raw) - x_train = _to_padded_jax_array(x_train) - flat_targets = _to_padded_jax_array(flat_targets) - x_test = _to_padded_jax_array(x_test) - beta_hat = ( - jnp.linalg.pinv( - x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]), - hermitian=True, - ) - @ x_train.T - @ flat_targets - ) - y_hat = x_test @ beta_hat - y_hat_context = x_train_raw @ beta_hat if debug_info else None - - outputs = [] - outputs_context = [] - - # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits. - train_index, test_index = 0, 0 - for train_index_delta, test_index_delta in zip(self.train_lens, self.test_lens): - outputs.append( - np.array(y_hat[test_index : (test_index + test_index_delta)]) - ) - if debug_info: - outputs_context.append( - np.array( - y_hat_context[train_index : (train_index + train_index_delta)] - ) - ) - train_index += train_index_delta - test_index += test_index_delta - - if debug_info: - return outputs, outputs_context, flat_targets, x_train, x_test - else: - return outputs From c042a9d7bbe67cc006493ea03fe7a2b1f819c215 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Mon, 30 Sep 2024 13:03:37 -0700 Subject: [PATCH 095/242] Adapting TimesFM to HF format --- src/transformers/models/timesfm/__init__.py | 6 +- .../models/timesfm/configuration_timesfm.py | 6 +- .../models/timesfm/modeling_timesfm.py | 626 ++++++++++++++++++ .../models/timesfm/patched_decoder.py | 22 +- src/transformers/models/timesfm/timesfm.py | 202 ------ .../models/timesfm/timesfm_base.py | 340 ---------- 6 files changed, 643 insertions(+), 559 deletions(-) create mode 100644 src/transformers/models/timesfm/modeling_timesfm.py delete mode 100644 src/transformers/models/timesfm/timesfm.py delete mode 100644 src/transformers/models/timesfm/timesfm_base.py diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index baa30b11af21..fe1a08da2678 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -29,8 +29,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["modeling_timesfm"] = [ - "TimesFMForPrediction", + _import_structure["timesfm"] = [ "TimesFMModel", "TimesFMPreTrainedModel", ] @@ -44,8 +43,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_timesfm import ( - TimesFMForPrediction, + from .timesfm import ( TimesFMModel, TimesFMPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index de82a874771b..29948593aff5 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -71,6 +71,8 @@ class TimesFMConfig(PretrainedConfig): initializer_factor (`float`, *optional*, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). + backend (`str`, *optional*, defaults to `"gpu"`): + The backend to use for the model. Can be either `"gpu"` or `"cpu"`. """ model_type = "timesfm" @@ -97,8 +99,9 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - per_core_batch_size: int = 32, + per_core_batch_size: int = 32, initializer_factor: float = 1.0, + backend: str = "gpu", **kwargs, ): self.patch_len = patch_len @@ -117,6 +120,7 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor + self.backend = backend super().__init__( **kwargs, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 000000000000..f2df0c061129 --- /dev/null +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,626 @@ +# coding=utf-8 +# Copyright 2024 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TimesFM model.""" + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### + + +import dataclasses +import logging +import multiprocessing +from typing import Any, Sequence +from os import path +import pandas as pd +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from ...modeling_utils import PreTrainedModel + +import patched_decoder as ppd +from utilsforecast.processing import make_future_dataframe +from configuration_timesfm import TimesFMConfig + + +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if ( + freq.endswith("H") + or freq.endswith("T") + or freq.endswith("MIN") + or freq.endswith("D") + or freq.endswith("B") + or freq.endswith("U") + ): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + +@dataclasses.dataclass(kw_only=True) +class TimesFmCheckpoint: + """Checkpoint used to initialize a TimesFM model for inference. + + Attributes: + version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. + The factory will create the corresponding TimesFm inference class based on + this version. + path: Path to the checkpoint. + type: If provided, type of the checkpoint used by the specific checkpoint + loader per version. + step: If provided, step of the checkpoint. + """ + + version: str = "torch" + path: str | None = None + huggingface_repo_id: str | None = None + type: Any = None + step: int | None = None + + +class TimesFmBase: + """Base TimesFM forecast API for inference. + + This class is the scaffolding for calling TimesFM forecast. To properly use: + 1. Create an instance with the correct hyperparameters of a TimesFM model. + 2. Call `load_from_checkpoint` to load a compatible checkpoint. + 3. Call `forecast` for inference. + """ + + def _logging(self, s): + print(s) + + def __init__(self, hparams: TimesFMConfig) -> None: + """Initializes the TimesFM forecast API. + + Args: + hparams: Hyperparameters of the model. + checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide + which TimesFM version to use. + """ + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.patch_len + self.output_patch_len = hparams.horizon_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dim + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + + # Rewrite these values in subclasses for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + self._horizon_start = self.context_len - self.input_patch_len + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.array], freq: Sequence[int] + ) -> tuple[np.array, np.array, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d JTensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ( + (len(inputs) - 1) // self.global_batch_size + 1 + ) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate( + [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 + ) + padding = np.concatenate( + [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 + ) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.array, np.array]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + raise NotImplementedError("`forecast` is not implemented.") + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: int | None = None, + num_jobs: int = 1, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. + + Args: + inputs: A pd.DataFrame of all time series. The dataframe should have a + `unique_id` column for identifying the time series, a `ds` column for + timestamps and a value column for the time series values. + freq: string valued `freq` of data. Notice this is different from the + `freq` required by `forecast`. See `freq_map` for allowed values. + forecast_context_len: If provided none zero, we take the last + `forecast_context_len` time-points from each series as the forecast + context instead of the `context_len` set by the model. + value_name: The name of the value column. + model_name: name of the model to be written into future df. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + num_jobs: number of parallel processes to use for dataframe processing. + verbose: output model states in terminal. + + Returns: + Future forecasts dataframe. + """ + if not ( + "unique_id" in inputs.columns + and "ds" in inputs.columns + and value_name in inputs.columns + ): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns." + ) + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + print("Processing dataframe with single process.") + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + print("Processing dataframe with multiple processes.") + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [ + (key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id") + ], + ) + new_inputs, uids = zip(*results) + if verbose: + print("Finished preprocessing dataframe.") + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast( + new_inputs, freq=freq_inps, window_size=window_size + ) + if verbose: + print("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, + ) + fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( + -1, 1 + ) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df + + +class TimesFMModel(TimesFmBase, nn.Module): + """Body of the TimesFM model, excluding the head.""" + + def __post_init__(self): + self._model_config = TimesFMConfig( + num_layers=self.num_layers, + num_heads=self.num_heads, + hidden_size=self.model_dims, + intermediate_size=self.model_dims, + patch_len=self.input_patch_len, + horizon_len=self.output_patch_len, + head_dim=self.model_dims // self.num_heads, + quantiles=self.quantiles, + ) + + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device( + "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" + ) + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + self._model.to(self._device) + self._model.eval() + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) + ) + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs + + +class TimesFMModel(TimesFmBase, nn.Module): + """TimesFM forecast API for inference.""" + + def __init__(self, hparams: TimesFMConfig) -> None: + super.__init__(hparams) + self._model_config = hparams + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device( + "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" + ) + + def forecast( + self, + inputs: Sequence[Any], + freq: Sequence[int] | None = None, + window_size: int | None = None, + forecast_context_len: int | None = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + truncate_negative: truncate to only non-negative values if all the contexts + have non-negative values. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + if not self._model: + raise ValueError( + "Checkpoint not loaded. Call `load_from_checkpoint` before" + " `forecast`." + ) + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) + ) + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs + + def forward(self, x, **kwargs): + if isinstance(x, pd.DataFrame): + assert "freq" in kwargs, "Frequency must be provided for DataFrame input." + return self.forecast_on_df(x, **kwargs) + else: + return self.forecast(x, **kwargs) + + +## TODO: Define the PreTrainedTimesFMModel class diff --git a/src/transformers/models/timesfm/patched_decoder.py b/src/transformers/models/timesfm/patched_decoder.py index f7e108bc08d8..baafe6be148c 100644 --- a/src/transformers/models/timesfm/patched_decoder.py +++ b/src/transformers/models/timesfm/patched_decoder.py @@ -31,22 +31,21 @@ def _masked_mean_std( It excludes values where `padding` is 1. Args: - inputs: A PyTorch tensor of shape [b, n, p]. - padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. Returns: - A tuple containing the mean and standard deviation. - We return the statistics of the first patch with more than three non-padded - values. + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded values. """ - # Selecting the first patch with more than 3 unpadded values. - pad_sum = torch.sum(1 - padding, dim=2) + # Selecting the first patch with more than 3 unpadded values. def _get_patch_index(arr: torch.Tensor): indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) row_sum = (arr >= 3).to(torch.int32).sum(dim=1) return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + pad_sum = torch.sum(1 - padding, dim=2) patch_indices = _get_patch_index(pad_sum) bidxs = torch.arange(inputs.shape[0]) @@ -57,9 +56,8 @@ def _get_patch_index(arr: torch.Tensor): mask = 1 - pad # Calculate the number of valid elements - num_valid_elements = torch.sum(mask, dim=1) num_valid_elements = torch.where( - num_valid_elements == 0, + torch.sum(mask, dim=1) == 0, torch.tensor( 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device ), @@ -87,11 +85,11 @@ def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: """Shifts rows of seq based on the first 0 in each row of the mask. Args: - mask: mask tensor of shape [B, N] - seq: seq tensor of shape [B, N, P] + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] Returns: - Returns the shifted sequence. + The shifted sequence. """ batch_size, num_seq, feature_dim = seq.shape diff --git a/src/transformers/models/timesfm/timesfm.py b/src/transformers/models/timesfm/timesfm.py deleted file mode 100644 index ea27c1e75b8c..000000000000 --- a/src/transformers/models/timesfm/timesfm.py +++ /dev/null @@ -1,202 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch TimesFM model.""" - - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### - -import logging -from os import path -from typing import Any, Sequence - -import numpy as np -import torch -from huggingface_hub import snapshot_download -import timesfm_base -import patched_decoder as ppd -from ...modeling_utils import PreTrainedModel - - -_TOL = 1e-6 - - -class TimesFmTorch(PreTrainedModel, timesfm_base.TimesFmBase): - """TimesFM forecast API for inference.""" - - def __post_init__(self): - self._model_config = ppd.TimesFMConfig( - num_layers=self.num_layers, - num_heads=self.num_heads, - hidden_size=self.model_dims, - intermediate_size=self.model_dims, - patch_len=self.input_patch_len, - horizon_len=self.output_patch_len, - head_dim=self.model_dims // self.num_heads, - quantiles=self.quantiles, - ) - self._model = None - self.num_cores = 1 - self.global_batch_size = self.per_core_batch_size - self._device = torch.device( - "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" - ) - - def load_from_checkpoint( - self, - checkpoint: timesfm_base.TimesFmCheckpoint, - ) -> None: - """Loads a checkpoint and compiles the decoder.""" - checkpoint_path = checkpoint.path - repo_id = checkpoint.huggingface_repo_id - if checkpoint_path is None: - checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt") - self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) - logging.info("Loading checkpoint from %s", checkpoint_path) - self._model.load_state_dict(loaded_checkpoint) - logging.info("Sending checkpoint to device %s", f"{self._device}") - self._model.to(self._device) - self._model.eval() - # TODO: add compilation. - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - if not self._model: - raise ValueError( - "Checkpoint not loaded. Call `load_from_checkpoint` before" - " `forecast`." - ) - if forecast_context_len is None: - fcontext_len = self.context_len - else: - fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) - - if window_size is not None: - new_inputs = [] - for ts in inputs: - new_inputs.extend(timesfm_base.moving_average(ts, window_size)) - inputs = new_inputs - - if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) - - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) - ) - mean_output, full_output = self._model.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] - - if window_size is not None: - mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] - full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs diff --git a/src/transformers/models/timesfm/timesfm_base.py b/src/transformers/models/timesfm/timesfm_base.py deleted file mode 100644 index 7c0c756e6847..000000000000 --- a/src/transformers/models/timesfm/timesfm_base.py +++ /dev/null @@ -1,340 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Base class for TimesFM inference. This will be common to PAX and Pytorch.""" - -import collections -import dataclasses -import logging -import multiprocessing -from typing import Any, Literal, Sequence - -import numpy as np -import pandas as pd - -from utilsforecast.processing import make_future_dataframe -from configuration_timesfm import TimesFMConfig - -_TOL = 1e-6 -DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) - - -def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key - - -def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size - return [smoothed_arr, arr - smoothed_arr] - - -def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if ( - freq.endswith("H") - or freq.endswith("T") - or freq.endswith("MIN") - or freq.endswith("D") - or freq.endswith("B") - or freq.endswith("U") - ): - return 0 - elif freq.endswith(("W", "M", "MS")): - return 1 - elif freq.endswith("Y") or freq.endswith("Q"): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") - - -# Per time series normalization: forward. -def normalize(batch): - stats = [(np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch] - new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] - return new_batch, stats - - -# Per time series normalization: inverse. -def renormalize(batch, stats): - return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] - - -@dataclasses.dataclass(kw_only=True) -class TimesFmCheckpoint: - """Checkpoint used to initialize a TimesFM model for inference. - - Attributes: - version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. - The factory will create the corresponding TimesFm inference class based on - this version. - path: Path to the checkpoint. - type: If provided, type of the checkpoint used by the specific checkpoint - loader per version. - step: If provided, step of the checkpoint. - """ - - version: str = "jax" - path: str | None = None - huggingface_repo_id: str | None = None - type: Any = None - step: int | None = None - - -class TimesFmBase: - """Base TimesFM forecast API for inference. - - This class is the scaffolding for calling TimesFM forecast. To properly use: - 1. Create an instance with the correct hyperparameters of a TimesFM model. - 2. Call `load_from_checkpoint` to load a compatible checkpoint. - 3. Call `forecast` for inference. - """ - - def _logging(self, s): - print(s) - - def __post_init__(self) -> None: - """Additional initialization for subclasses before checkpoint loading.""" - pass - - def __init__(self, hparams: TimesFMConfig, checkpoint: TimesFmCheckpoint) -> None: - """Initializes the TimesFM forecast API. - - Args: - hparams: Hyperparameters of the model. - checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide - which TimesFM version to use. - """ - self.hparams = hparams - - # Expand hparams for conciseness within the model code. - self.context_len = hparams.context_len - self.horizon_len = hparams.horizon_len - self.input_patch_len = hparams.patch_len - self.output_patch_len = hparams.horizon_len - self.num_layers = hparams.num_layers - self.model_dims = hparams.model_dim - self.backend = hparams.backend - self.quantiles = hparams.quantiles - self.num_heads = hparams.num_heads - - # Rewrite these values in __post_init__ for SPMD. - self.num_cores = 1 - self.per_core_batch_size = hparams.per_core_batch_size - self.global_batch_size = hparams.per_core_batch_size - - self._horizon_start = self.context_len - self.input_patch_len - self.__post_init__() - self.load_from_checkpoint(checkpoint) - - def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: - """Loads a checkpoint and compiles the decoder.""" - raise NotImplementedError("`load_from_checkpoint` is not implemented.") - - def _preprocess( - self, inputs: Sequence[np.array], freq: Sequence[int] - ) -> tuple[np.array, np.array, int]: - """Formats and pads raw inputs to feed into the model. - - This function both pads each time series to match the context length, and - pads the inputs to meet the SPMD shape requirement. - - Args: - inputs: A list of 1d JTensors. Each JTensor is the context time series of - a single forecast task. - freq: list of frequencies - - Returns: - A tuple of: - - the padded input time series to meet the model required context. - - the padding indicator. - - the number of padded examples for SPMD so that each core has the same - number (a multiple of `batch_size`) of examples. - """ - - input_ts, input_padding, inp_freq = [], [], [] - - pmap_pad = ( - (len(inputs) - 1) // self.global_batch_size + 1 - ) * self.global_batch_size - len(inputs) - - for i, ts in enumerate(inputs): - input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) - if input_len < self.context_len: - num_front_pad = self.context_len - input_len - ts = np.concatenate( - [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 - ) - padding = np.concatenate( - [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 - ) - elif input_len > self.context_len: - ts = ts[-self.context_len :] - padding = padding[-(self.context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) - inp_freq.append(freq[i]) - - # Padding the remainder batch. - for _ in range(pmap_pad): - input_ts.append(input_ts[-1]) - input_padding.append(input_padding[-1]) - inp_freq.append(inp_freq[-1]) - - return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), - pmap_pad, - ) - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.array, np.array]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - raise NotImplementedError("`forecast` is not implemented.") - - def forecast_on_df( - self, - inputs: pd.DataFrame, - freq: str, - forecast_context_len: int = 0, - value_name: str = "values", - model_name: str = "timesfm", - window_size: int | None = None, - num_jobs: int = 1, - verbose: bool = True, - ) -> pd.DataFrame: - """Forecasts on a list of time series. - - Args: - inputs: A pd.DataFrame of all time series. The dataframe should have a - `unique_id` column for identifying the time series, a `ds` column for - timestamps and a value column for the time series values. - freq: string valued `freq` of data. Notice this is different from the - `freq` required by `forecast`. See `freq_map` for allowed values. - forecast_context_len: If provided none zero, we take the last - `forecast_context_len` time-points from each series as the forecast - context instead of the `context_len` set by the model. - value_name: The name of the value column. - model_name: name of the model to be written into future df. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - num_jobs: number of parallel processes to use for dataframe processing. - verbose: output model states in terminal. - - Returns: - Future forecasts dataframe. - """ - if not ( - "unique_id" in inputs.columns - and "ds" in inputs.columns - and value_name in inputs.columns - ): - raise ValueError( - f"DataFrame must have unique_id, ds and {value_name} columns." - ) - if not forecast_context_len: - forecast_context_len = self.context_len - logging.info("Preprocessing dataframe.") - df_sorted = inputs.sort_values(by=["unique_id", "ds"]) - new_inputs = [] - uids = [] - if num_jobs == 1: - if verbose: - print("Processing dataframe with single process.") - for key, group in df_sorted.groupby("unique_id"): - inp, uid = process_group( - key, - group, - value_name, - forecast_context_len, - ) - new_inputs.append(inp) - uids.append(uid) - else: - if num_jobs == -1: - num_jobs = multiprocessing.cpu_count() - if verbose: - print("Processing dataframe with multiple processes.") - with multiprocessing.Pool(processes=num_jobs) as pool: - results = pool.starmap( - process_group, - [ - (key, group, value_name, forecast_context_len) - for key, group in df_sorted.groupby("unique_id") - ], - ) - new_inputs, uids = zip(*results) - if verbose: - print("Finished preprocessing dataframe.") - freq_inps = [freq_map(freq)] * len(new_inputs) - _, full_forecast = self.forecast( - new_inputs, freq=freq_inps, window_size=window_size - ) - if verbose: - print("Finished forecasting.") - fcst_df = make_future_dataframe( - uids=uids, - last_times=df_sorted.groupby("unique_id")["ds"].tail(1), - h=self.horizon_len, - freq=freq, - ) - fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) - - for i, q in enumerate(self.quantiles): - q_col = f"{model_name}-q-{q}" - fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( - -1, 1 - ) - if q == 0.5: - fcst_df[model_name] = fcst_df[q_col] - logging.info("Finished creating output dataframe.") - return fcst_df From a52eeca07418e212f1a92ea0a0e52ebfbf45cfd1 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Tue, 1 Oct 2024 17:15:11 -0700 Subject: [PATCH 096/242] restructing in progress --- src/transformers/models/timesfm/__init__.py | 2 +- .../models/timesfm/modeling_timesfm.py | 146 ------------------ 2 files changed, 1 insertion(+), 147 deletions(-) diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index fe1a08da2678..91f4693ae2e5 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -43,7 +43,7 @@ except OptionalDependencyNotAvailable: pass else: - from .timesfm import ( + from .modeling_timesfm import ( TimesFMModel, TimesFMPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f2df0c061129..64ee4d5f8af6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -330,152 +330,6 @@ def forecast_on_df( return fcst_df -class TimesFMModel(TimesFmBase, nn.Module): - """Body of the TimesFM model, excluding the head.""" - - def __post_init__(self): - self._model_config = TimesFMConfig( - num_layers=self.num_layers, - num_heads=self.num_heads, - hidden_size=self.model_dims, - intermediate_size=self.model_dims, - patch_len=self.input_patch_len, - horizon_len=self.output_patch_len, - head_dim=self.model_dims // self.num_heads, - quantiles=self.quantiles, - ) - - self.num_cores = 1 - self.global_batch_size = self.per_core_batch_size - self._device = torch.device( - "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" - ) - self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - self._model.to(self._device) - self._model.eval() - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - if forecast_context_len is None: - fcontext_len = self.context_len - else: - fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) - - if window_size is not None: - new_inputs = [] - for ts in inputs: - new_inputs.extend(moving_average(ts, window_size)) - inputs = new_inputs - - if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) - - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) - ) - mean_output, full_output = self._model.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] - - if window_size is not None: - mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] - full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs - - class TimesFMModel(TimesFmBase, nn.Module): """TimesFM forecast API for inference.""" From c7f760ea1ca9c067a72d81501746373dbc929ebe Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 2 Oct 2024 17:43:13 -0700 Subject: [PATCH 097/242] adapted to HF convention --- src/transformers/__init__.py | 2 - .../models/timesfm/modeling_timesfm.py | 510 +++++++++--------- .../{patched_decoder.py => timesfm_layers.py} | 254 ++------- 3 files changed, 305 insertions(+), 461 deletions(-) rename src/transformers/models/timesfm/{patched_decoder.py => timesfm_layers.py} (66%) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b46578c461bf..5cee1d88630f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3710,7 +3710,6 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMForPrediction", "TimesFMModel", "TimesFMPreTrainedModel", ] @@ -8455,7 +8454,6 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFMForPrediction, TimesFMModel, TimesFMPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 64ee4d5f8af6..adb5a9d31006 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -22,119 +22,290 @@ #################################################### -import dataclasses import logging import multiprocessing from typing import Any, Sequence -from os import path import pandas as pd import numpy as np import torch import torch.nn as nn -from huggingface_hub import snapshot_download from ...modeling_utils import PreTrainedModel +from .configuration_timesfm import TimesFMConfig +from .timesfm_layers import * -import patched_decoder as ppd +# TODO: shall remove this dependency after API design is finalized. from utilsforecast.processing import make_future_dataframe -from configuration_timesfm import TimesFMConfig - - -def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key - - -def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size - return [smoothed_arr, arr - smoothed_arr] - - -def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if ( - freq.endswith("H") - or freq.endswith("T") - or freq.endswith("MIN") - or freq.endswith("D") - or freq.endswith("B") - or freq.endswith("U") - ): - return 0 - elif freq.endswith(("W", "M", "MS")): - return 1 - elif freq.endswith("Y") or freq.endswith("Q"): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") - - -@dataclasses.dataclass(kw_only=True) -class TimesFmCheckpoint: - """Checkpoint used to initialize a TimesFM model for inference. - - Attributes: - version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. - The factory will create the corresponding TimesFm inference class based on - this version. - path: Path to the checkpoint. - type: If provided, type of the checkpoint used by the specific checkpoint - loader per version. - step: If provided, step of the checkpoint. - """ - - version: str = "torch" - path: str | None = None - huggingface_repo_id: str | None = None - type: Any = None - step: int | None = None - - -class TimesFmBase: - """Base TimesFM forecast API for inference. - - This class is the scaffolding for calling TimesFM forecast. To properly use: - 1. Create an instance with the correct hyperparameters of a TimesFM model. - 2. Call `load_from_checkpoint` to load a compatible checkpoint. - 3. Call `forecast` for inference. - """ - - def _logging(self, s): - print(s) - - def __init__(self, hparams: TimesFMConfig) -> None: - """Initializes the TimesFM forecast API. + + +class TimesFMPreTrainedModel(PreTrainedModel): + """handles the loading for all models.""" + + config_class = TimesFMConfig + base_model_prefix = "timesfm" + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): + nn.init.uniform_(module.weight, a=-0.1, b=0.1) + + elif isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + elif isinstance(module, RMSNorm): + nn.init.zeros_(module.weight) + + elif isinstance(module, PositionalEmbedding): + pass + + +class PatchedTimeSeriesDecoder(TimesFMPreTrainedModel): + """Patched time-series decoder.""" + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + + self.config = config + self.input_ff_layer = ResidualBlock( + input_dims=2 * config.patch_len, + output_dims=config.model_dim, + hidden_dims=config.model_dim, + ) + self.freq_emb = nn.Embedding( + num_embeddings=config.freq_size, embedding_dim=config.model_dim + ) + self.horizon_ff_layer = ResidualBlock( + input_dims=config.model_dim, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.model_dim, + ) + self.stacked_transformer = StackedDecoder( + hidden_size=self.config.model_dim, + intermediate_size=self.config.model_dim, + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_heads, + head_dim=self.config.head_dim, + num_layers=self.config.num_layers, + rms_norm_eps=self.config.rms_norm_eps, + ) + if self.config.use_positional_embedding: + self.position_emb = PositionalEmbedding( + embedding_dims=self.config.model_dim, + ) + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=outputs.dtype, device=outputs.device + ), + outputs, + ) + return outputs, (mu, sigma) + + def _reverse_transform( + self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] + + def _preprocess_input( + self, + input_ts: torch.Tensor, + input_padding: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor, torch.Tensor] | None, + torch.Tensor, + ]: + """Preprocess input for stacked transformer.""" + + # Reshape into patches (using view for efficiency) + bsize = input_ts.shape[0] + patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) + patched_pads = input_padding.view(bsize, -1, self.config.patch_len) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, dim=-1)[ + 0 + ] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + return model_input, patched_padding, stats, patched_inputs + + def _postprocess_output( + self, + model_output: torch.Tensor, + num_outputs: int, + stats: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) + + return self._reverse_transform(output_ts, stats) + + def forward( + self, + input_ts: torch.Tensor, + input_padding: torch.LongTensor, + freq: torch.Tensor, + ) -> torch.Tensor: + num_outputs = len(self.config.quantiles) + 1 + model_input, patched_padding, stats, _ = self._preprocess_input( + input_ts=input_ts, + input_padding=input_padding, + ) + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + model_output = self.stacked_transformer(model_input, patched_padding) + + output_ts = self._postprocess_output(model_output, num_outputs, stats) + return output_ts + + def decode( + self, + input_ts: torch.Tensor, + paddings: torch.Tensor, + freq: torch.LongTensor, + horizon_len: int, + output_patch_len: int | None = None, + max_len: int = 512, + return_forecast_on_context: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Auto-regressive decoding without caching. Args: - hparams: Hyperparameters of the model. - checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide - which TimesFM version to use. + input_ts: input time-series and paddings. Time-series shape B x C. + paddings: padding shape B x (C + H) where H is the prediction length. + freq: frequency shape B x 1 + horizon_len: prediction length. + output_patch_len: output length to be fetched from one step of + auto-regressive decoding. + max_len: maximum training context length. + return_forecast_on_context: whether to return the model forecast on the + context except the first input patch. + + Returns: + Tuple of two forecasting results: + - Point (mean) output predictions as a tensor with shape B x H'. + - Full predictions (mean and quantiles) as a tensor with shape + B x H' x (1 + # quantiles). + In particular, if return_forecast_on_context is True, H' is H plus + the forecastable context length, i.e. context_len - (first) patch_len. """ - self.hparams = hparams - - # Expand hparams for conciseness within the model code. - self.context_len = hparams.context_len - self.horizon_len = hparams.horizon_len - self.input_patch_len = hparams.patch_len - self.output_patch_len = hparams.horizon_len - self.num_layers = hparams.num_layers - self.model_dims = hparams.model_dim - self.backend = hparams.backend - self.quantiles = hparams.quantiles - self.num_heads = hparams.num_heads - - # Rewrite these values in subclasses for SPMD. + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" + ) + if output_patch_len is None: + output_patch_len = self.config.horizon_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = paddings[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -max_len:] + input_padding = current_padding[:, -max_len:] + fprop_outputs = self(input_ts, input_padding, freq) + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] + new_full_ts = fprop_outputs.view( + new_full_ts.size(0), -1, new_full_ts.size(3) + ) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concatenate([final_out, new_ts], axis=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concatenate(full_outputs, axis=1)[ + :, : (context_len - self.config.patch_len + horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] + + return (full_outputs[:, :, 0], full_outputs) + + +class TimesFMModel(TimesFMPreTrainedModel): + def __init__(self, config: TimesFMConfig): + super().__init__(config) + + self.config = config + + self.decoder = PatchedTimeSeriesDecoder(config) + + self.context_len = config.context_len + self.horizon_len = config.horizon_len + self.input_patch_len = config.patch_len + self.output_patch_len = config.horizon_len + self.num_layers = config.num_layers + self.model_dims = config.model_dim + self.backend = config.backend + self.quantiles = config.quantiles + self.num_heads = config.num_heads + self.num_cores = 1 - self.per_core_batch_size = hparams.per_core_batch_size - self.global_batch_size = hparams.per_core_batch_size + self.per_core_batch_size = config.per_core_batch_size + self.global_batch_size = config.per_core_batch_size * self.num_cores self._horizon_start = self.context_len - self.input_patch_len - - def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: - """Loads a checkpoint and compiles the decoder.""" - raise NotImplementedError("`load_from_checkpoint` is not implemented.") + self._device = config.backend def _preprocess( self, inputs: Sequence[np.array], freq: Sequence[int] @@ -329,152 +500,9 @@ def forecast_on_df( logging.info("Finished creating output dataframe.") return fcst_df - -class TimesFMModel(TimesFmBase, nn.Module): - """TimesFM forecast API for inference.""" - - def __init__(self, hparams: TimesFMConfig) -> None: - super.__init__(hparams) - self._model_config = hparams - self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - self.num_cores = 1 - self.global_batch_size = self.per_core_batch_size - self._device = torch.device( - "cuda:0" if (torch.cuda.is_available() and self.backend == "gpu") else "cpu" - ) - - def forecast( - self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, - window_size: int | None = None, - forecast_context_len: int | None = None, - return_forecast_on_context: bool = False, - truncate_negative: bool = False, - ) -> tuple[np.ndarray, np.ndarray]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - - Returns: - A tuple for JTensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. - """ - if not self._model: - raise ValueError( - "Checkpoint not loaded. Call `load_from_checkpoint` before" - " `forecast`." - ) - if forecast_context_len is None: - fcontext_len = self.context_len - else: - fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) - - if window_size is not None: - new_inputs = [] - for ts in inputs: - new_inputs.extend(moving_average(ts, window_size)) - inputs = new_inputs - - if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) - - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size - ], - dtype=np.float32, - ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) - ) - mean_output, full_output = self._model.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] - - if window_size is not None: - mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] - full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs - def forward(self, x, **kwargs): if isinstance(x, pd.DataFrame): assert "freq" in kwargs, "Frequency must be provided for DataFrame input." return self.forecast_on_df(x, **kwargs) else: return self.forecast(x, **kwargs) - - -## TODO: Define the PreTrainedTimesFMModel class diff --git a/src/transformers/models/timesfm/patched_decoder.py b/src/transformers/models/timesfm/timesfm_layers.py similarity index 66% rename from src/transformers/models/timesfm/patched_decoder.py rename to src/transformers/models/timesfm/timesfm_layers.py index baafe6be148c..713623eb98a7 100644 --- a/src/transformers/models/timesfm/patched_decoder.py +++ b/src/transformers/models/timesfm/timesfm_layers.py @@ -17,13 +17,13 @@ import math from typing import List, Tuple +import numpy as np import torch from torch import nn import torch.nn.functional as F -from transformers.models.timesfm.configuration_timesfm import TimesFMConfig -def _masked_mean_std( +def masked_mean_std( inputs: torch.Tensor, padding: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates mean and standard deviation of `inputs` across axis 1. @@ -81,7 +81,7 @@ def _get_patch_index(arr: torch.Tensor): return masked_mean, masked_std -def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: +def shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: """Shifts rows of seq based on the first 0 in each row of the mask. Args: @@ -213,6 +213,39 @@ def expand_t(key_mask): return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if ( + freq.endswith("H") + or freq.endswith("T") + or freq.endswith("MIN") + or freq.endswith("D") + or freq.endswith("B") + or freq.endswith("U") + ): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + class ResidualBlock(nn.Module): """TimesFM residual block.""" @@ -547,218 +580,3 @@ def forward(self, seq_length=None, position=None): # Padding to ensure correct embedding dimension signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) return signal - - -class PatchedTimeSeriesDecoder(nn.Module): - """Patched time-series decoder.""" - - def __init__(self, config: TimesFMConfig): - super().__init__() - self.config = config - self.input_ff_layer = ResidualBlock( - input_dims=2 * config.patch_len, - output_dims=config.model_dim, - hidden_dims=config.model_dim, - ) - self.freq_emb = nn.Embedding(num_embeddings=3, embedding_dim=config.model_dim) - self.horizon_ff_layer = ResidualBlock( - input_dims=config.model_dim, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.model_dim, - ) - self.stacked_transformer = StackedDecoder( - hidden_size=self.config.model_dim, - intermediate_size=self.config.model_dim, - num_heads=self.config.num_heads, - num_kv_heads=self.config.num_heads, - head_dim=self.config.head_dim, - num_layers=self.config.num_layers, - rms_norm_eps=self.config.rms_norm_eps, - ) - if self.config.use_positional_embedding: - self.position_emb = PositionalEmbedding(self.config.model_dim) - - def _forward_transform( - self, inputs: torch.Tensor, patched_pads: torch.Tensor - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """Input is of shape [B, N, P].""" - mu, sigma = _masked_mean_std(inputs, patched_pads) - sigma = torch.where( - sigma < self.config.tolerance, - torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), - sigma, - ) - - # Normalize each patch - outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] - outputs = torch.where( - torch.abs(inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor( - self.config.pad_val, dtype=outputs.dtype, device=outputs.device - ), - outputs, - ) - return outputs, (mu, sigma) - - def _reverse_transform( - self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] - ) -> torch.Tensor: - """Output is of shape [B, N, P, Q].""" - mu, sigma = stats - return outputs * sigma[:, None, None, None] + mu[:, None, None, None] - - def _preprocess_input( - self, - input_ts: torch.Tensor, - input_padding: torch.Tensor, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - tuple[torch.Tensor, torch.Tensor] | None, - torch.Tensor, - ]: - """Preprocess input for stacked transformer.""" - - # Reshape into patches (using view for efficiency) - bsize = input_ts.shape[0] - patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) - patched_pads = input_padding.view(bsize, -1, self.config.patch_len) - - patched_inputs = torch.where( - torch.abs(patched_pads - 1.0) < self.config.tolerance, - torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), - patched_inputs, - ) - patched_pads = torch.where( - torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), - patched_pads, - ) - patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) - - # B x N x D - patched_inputs = patched_inputs * (1.0 - patched_pads) - concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) - model_input = self.input_ff_layer(concat_inputs) - - # A patch should not be padded even if there is at least one zero. - patched_padding = torch.min(patched_pads, dim=-1)[ - 0 - ] # Get the values from the min result - if self.config.use_positional_embedding: - pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) - pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) - pos_emb = _shift_padded_seq(patched_padding, pos_emb) - model_input += pos_emb - - return model_input, patched_padding, stats, patched_inputs - - def _postprocess_output( - self, - model_output: torch.Tensor, - num_outputs: int, - stats: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - """Postprocess output of stacked transformer.""" - - # B x N x (H.Q) - output_ts = self.horizon_ff_layer(model_output) - - # Reshape using view - b, n, _ = output_ts.shape - output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) - - return self._reverse_transform(output_ts, stats) - - def forward( - self, - input_ts: torch.Tensor, - input_padding: torch.LongTensor, - freq: torch.Tensor, - ) -> torch.Tensor: - num_outputs = len(self.config.quantiles) + 1 - model_input, patched_padding, stats, _ = self._preprocess_input( - input_ts=input_ts, - input_padding=input_padding, - ) - f_emb = self.freq_emb(freq) # B x 1 x D - model_input += f_emb - model_output = self.stacked_transformer(model_input, patched_padding) - - output_ts = self._postprocess_output(model_output, num_outputs, stats) - return output_ts - - def decode( - self, - input_ts: torch.Tensor, - paddings: torch.Tensor, - freq: torch.LongTensor, - horizon_len: int, - output_patch_len: int | None = None, - max_len: int = 512, - return_forecast_on_context: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Auto-regressive decoding without caching. - - Args: - input_ts: input time-series and paddings. Time-series shape B x C. - paddings: padding shape B x (C + H) where H is the prediction length. - freq: frequency shape B x 1 - horizon_len: prediction length. - output_patch_len: output length to be fetched from one step of - auto-regressive decoding. - max_len: maximum training context length. - return_forecast_on_context: whether to return the model forecast on the - context except the first input patch. - - Returns: - Tuple of two forecasting results: - - Point (mean) output predictions as a tensor with shape B x H'. - - Full predictions (mean and quantiles) as a tensor with shape - B x H' x (1 + # quantiles). - In particular, if return_forecast_on_context is True, H' is H plus - the forecastable context length, i.e. context_len - (first) patch_len. - """ - final_out = input_ts - context_len = final_out.shape[1] - full_outputs = [] - if paddings.shape[1] != final_out.shape[1] + horizon_len: - raise ValueError( - "Length of paddings must match length of input + horizon_len:" - f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" - ) - if output_patch_len is None: - output_patch_len = self.config.horizon_len - num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len - for step_index in range(num_decode_patches): - current_padding = paddings[:, 0 : final_out.shape[1]] - input_ts = final_out[:, -max_len:] - input_padding = current_padding[:, -max_len:] - fprop_outputs = self(input_ts, input_padding, freq) - if return_forecast_on_context and step_index == 0: - # For the first decodings step, collect the model forecast on the - # context except the unavailable first input batch forecast. - new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - new_full_ts = fprop_outputs.view( - new_full_ts.size(0), -1, new_full_ts.size(3) - ) - - full_outputs.append(new_full_ts) - - # (full batch, last patch, output_patch_len, index of mean forecast = 0) - new_ts = fprop_outputs[:, -1, :output_patch_len, 0] - new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] - # (full batch, last patch, output_patch_len, all output indices) - full_outputs.append(new_full_ts) - final_out = torch.concatenate([final_out, new_ts], axis=-1) - - if return_forecast_on_context: - # `full_outputs` indexing starts at after the first input patch. - full_outputs = torch.concatenate(full_outputs, axis=1)[ - :, : (context_len - self.config.patch_len + horizon_len), : - ] - else: - # `full_outputs` indexing starts at the forecast horizon. - full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - - return (full_outputs[:, :, 0], full_outputs) From d71713239da3ebc1be855c82d18e8e26b56b8194 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 9 Oct 2024 17:48:35 -0700 Subject: [PATCH 098/242] timesfm test --- .../models/timesfm/modeling_timesfm.py | 195 ++- tests/models/timesfm/test_modeling_timesfm.py | 1347 +---------------- 2 files changed, 170 insertions(+), 1372 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index adb5a9d31006..612493e7f7dd 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -23,7 +23,6 @@ import logging -import multiprocessing from typing import Any, Sequence import pandas as pd import numpy as np @@ -33,9 +32,6 @@ from .configuration_timesfm import TimesFMConfig from .timesfm_layers import * -# TODO: shall remove this dependency after API design is finalized. -from utilsforecast.processing import make_future_dataframe - class TimesFMPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" @@ -366,7 +362,7 @@ def _preprocess( pmap_pad, ) - def forecast( + def forward( self, inputs: Sequence[Any], freq: Sequence[int] | None = None, @@ -374,7 +370,7 @@ def forecast( forecast_context_len: int | None = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, - ) -> tuple[np.array, np.array]: + ) -> tuple[np.ndarray, np.ndarray]: """Forecasts on a list of time series. Args: @@ -400,109 +396,90 @@ def forecast( Raises: ValueError: If the checkpoint is not properly loaded. """ - raise NotImplementedError("`forecast` is not implemented.") - - def forecast_on_df( - self, - inputs: pd.DataFrame, - freq: str, - forecast_context_len: int = 0, - value_name: str = "values", - model_name: str = "timesfm", - window_size: int | None = None, - num_jobs: int = 1, - verbose: bool = True, - ) -> pd.DataFrame: - """Forecasts on a list of time series. - Args: - inputs: A pd.DataFrame of all time series. The dataframe should have a - `unique_id` column for identifying the time series, a `ds` column for - timestamps and a value column for the time series values. - freq: string valued `freq` of data. Notice this is different from the - `freq` required by `forecast`. See `freq_map` for allowed values. - forecast_context_len: If provided none zero, we take the last - `forecast_context_len` time-points from each series as the forecast - context instead of the `context_len` set by the model. - value_name: The name of the value column. - model_name: name of the model to be written into future df. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - num_jobs: number of parallel processes to use for dataframe processing. - verbose: output model states in terminal. - - Returns: - Future forecasts dataframe. - """ - if not ( - "unique_id" in inputs.columns - and "ds" in inputs.columns - and value_name in inputs.columns - ): - raise ValueError( - f"DataFrame must have unique_id, ds and {value_name} columns." - ) - if not forecast_context_len: - forecast_context_len = self.context_len - logging.info("Preprocessing dataframe.") - df_sorted = inputs.sort_values(by=["unique_id", "ds"]) - new_inputs = [] - uids = [] - if num_jobs == 1: - if verbose: - print("Processing dataframe with single process.") - for key, group in df_sorted.groupby("unique_id"): - inp, uid = process_group( - key, - group, - value_name, - forecast_context_len, - ) - new_inputs.append(inp) - uids.append(uid) + if forecast_context_len is None: + fcontext_len = self.context_len else: - if num_jobs == -1: - num_jobs = multiprocessing.cpu_count() - if verbose: - print("Processing dataframe with multiple processes.") - with multiprocessing.Pool(processes=num_jobs) as pool: - results = pool.starmap( - process_group, - [ - (key, group, value_name, forecast_context_len) - for key, group in df_sorted.groupby("unique_id") - ], + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + inp_min = np.min([np.min(ts) for ts in inputs]) + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size + ], + dtype=np.float32, + ) + ).to(self._device) + inp_freq_in = ( + torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ) + .long() + .to(self._device) ) - new_inputs, uids = zip(*results) - if verbose: - print("Finished preprocessing dataframe.") - freq_inps = [freq_map(freq)] * len(new_inputs) - _, full_forecast = self.forecast( - new_inputs, freq=freq_inps, window_size=window_size - ) - if verbose: - print("Finished forecasting.") - fcst_df = make_future_dataframe( - uids=uids, - last_times=df_sorted.groupby("unique_id")["ds"].tail(1), - h=self.horizon_len, - freq=freq, - ) - fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(-1, 1) - - for i, q in enumerate(self.quantiles): - q_col = f"{model_name}-q-{q}" - fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape( - -1, 1 - ) - if q == 0.5: - fcst_df[model_name] = fcst_df[q_col] - logging.info("Finished creating output dataframe.") - return fcst_df - - def forward(self, x, **kwargs): - if isinstance(x, pd.DataFrame): - assert "freq" in kwargs, "Frequency must be provided for DataFrame input." - return self.forecast_on_df(x, **kwargs) - else: - return self.forecast(x, **kwargs) + mean_output, full_output = self.decoder.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = np.maximum(mean_outputs, 0.0) + full_outputs = np.maximum(full_outputs, 0.0) + return mean_outputs, full_outputs diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index e08277fac50f..2aafe4133c8d 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -14,11 +14,9 @@ # limitations under the License. -import copy -import os -import pickle -import tempfile +import numpy as np import unittest +from typing import List from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( @@ -33,7 +31,7 @@ from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor +from ...test_modeling_common import ModelTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin @@ -46,10 +44,7 @@ from transformers import ( AutoTokenizer, - ByT5Tokenizer, - TimesFMForPrediction, TimesFMModel, - T5Tokenizer, ) @@ -57,1295 +52,121 @@ class TimesFMModelTester: def __init__( self, parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - decoder_seq_length=7, - # For common tests - is_training=True, - use_attention_mask=True, - use_labels=True, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - dropout_rate=0.1, - initializer_factor=0.002, - eos_token_id=1, - pad_token_id=0, - decoder_start_token_id=0, - scope=None, - decoder_layers=None, + patch_len: int = 32, + context_len: int = 512, + horizon_len: int = 128, + freq_size: int = 3, + num_layers: int = 20, + model_dim: int = 1280, + head_dim: int = 80, + num_heads: int = 16, + dropout_rate: float = 0.1, + tolerance: float = 1e-6, + rms_norm_eps: float = 1e-6, + quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + pad_val: float = 1123581321.0, + use_positional_embedding: bool = True, + per_core_batch_size: int = 32, + initializer_factor: float = 1.0, + backend: str = "gpu", ): self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - self.decoder_seq_length = decoder_seq_length - # For common tests - self.seq_length = self.decoder_seq_length - self.is_training = is_training - self.use_attention_mask = use_attention_mask - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets + self.patch_len = patch_len + self.context_len = context_len + self.horizon_len = horizon_len + self.quantiles = quantiles + self.pad_val = pad_val + self.freq_size = freq_size + self.model_dim = model_dim + self.head_dim = head_dim + self.num_layers = num_layers + self.num_heads = num_heads self.dropout_rate = dropout_rate + self.tolerance = tolerance + self.rms_norm_eps = rms_norm_eps + self.use_positional_embedding = use_positional_embedding + self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.decoder_start_token_id = decoder_start_token_id - self.scope = None - self.decoder_layers = decoder_layers + self.backend = backend def get_large_model_config(self): - return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2) - input_ids[:, -1] = self.eos_token_id # Eos Token - decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) - - attention_mask = None - decoder_attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) - - lm_labels = None - if self.use_labels: - lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) - - config = self.get_config() - - return ( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) - - def get_pipeline_config(self): - return TimesFMConfig( - vocab_size=166, # timesfm forces 100 extra tokens - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_decoder_layers=self.decoder_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.decoder_start_token_id, - ) + return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") def get_config(self): return TimesFMConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_decoder_layers=self.decoder_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, + patch_len=self.patch_len, + context_len=self.context_len, + horizon_len=self.horizon_len, + quantiles=self.quantiles, + pad_val=self.pad_val, + freq_size=self.freq_size, + model_dim=self.model_dim, + head_dim=self.head_dim, + num_layers=self.num_layers, + num_heads=self.num_heads, dropout_rate=self.dropout_rate, + tolerance=self.tolerance, + rms_norm_eps=self.rms_norm_eps, + use_positional_embedding=self.use_positional_embedding, + per_core_batch_size=self.per_core_batch_size, initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.decoder_start_token_id, + backend=self.backend, ) - def check_prepare_lm_labels_via_shift_left( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config) - model.to(torch_device) - model.eval() - - # make sure that lm_labels are correctly padded from the right - lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) - - # add casaul pad token mask - triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() - lm_labels.masked_fill_(triangular_mask, self.pad_token_id) - decoder_input_ids = model._shift_right(lm_labels) - - for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): - # first item - self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) - if i < decoder_input_ids_slice.shape[-1]: - if i < decoder_input_ids.shape[-1] - 1: - # items before diagonal - self.parent.assertListEqual( - decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() - ) - # pad items after diagonal - if i < decoder_input_ids.shape[-1] - 2: - self.parent.assertListEqual( - decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() - ) - else: - # all items after square - self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) - - def create_and_check_model( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config) - model.to(torch_device) - model.eval() - result = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - decoder_output = result.last_hidden_state - decoder_past = result.past_key_values - encoder_output = result.encoder_last_hidden_state - - self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) - self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) - # There should be `num_layers` key value embeddings stored in decoder_past - self.parent.assertEqual(len(decoder_past), config.num_layers) - # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple - self.parent.assertEqual(len(decoder_past[0]), 4) - - def create_and_check_with_lm_head( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMForPrediction(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - labels=lm_labels, - ) - self.parent.assertEqual(len(outputs), 4) - self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) - self.parent.assertEqual(outputs["loss"].size(), ()) - - def create_and_check_decoder_model_past( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() - # first forward pass - outputs = model(input_ids, use_cache=True) - outputs_use_cache_conf = model(input_ids) - outputs_no_past = model(input_ids, use_cache=False) - - self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) - self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) - - output, past_key_values = outputs.to_tuple() - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) - - # append to next input_ids and - next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - - output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] - - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() - output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() - - # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - - def create_and_check_decoder_model_attention_mask_past( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).get_decoder() - model.to(torch_device) - model.eval() - - # create attention mask - attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) - - half_seq_length = input_ids.shape[-1] // 2 - attn_mask[:, half_seq_length:] = 0 - - # first forward pass - output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) - - # change a random masked slice from input_ids - random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 - random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) - input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens - - # append to next input_ids and attn_mask - next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - attn_mask = torch.cat( - [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], - dim=1, - ) - - # get two different outputs - output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ - "last_hidden_state" - ] - - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() - output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() - - # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - - def create_and_check_decoder_model_past_large_inputs( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).get_decoder().to(torch_device).eval() - # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) - - output, past_key_values = outputs.to_tuple() - - # create hypothetical multiple next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) - next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) - - # append to next input_ids and - next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + def get_pipeline_config(self): + return self.get_config() - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ - "last_hidden_state" + def prepare_config_and_inputs(self): + forecast_input = [ + np.sin(np.linspace(0, 20, 100)), + np.sin(np.linspace(0, 20, 200)), + np.sin(np.linspace(0, 20, 400)), ] + frequency_input = [0, 1, 2] - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() - output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() - - self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) - - # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + config = self.get_config() - def create_and_check_generate_with_past_key_values( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMForPrediction(config=config).to(torch_device).eval() - torch.manual_seed(0) - output_without_past_cache = model.generate( - input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False + return ( + config, + forecast_input, + frequency_input, ) - torch.manual_seed(0) - output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) - self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) - - def create_and_check_model_fp16_forward( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - model = TimesFMModel(config=config).to(torch_device).half().eval() - output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def create_and_check_encoder_decoder_shared_weights( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - for model_class in [TimesFMModel, TimesFMForPrediction]: - torch.manual_seed(0) - model = model_class(config=config).to(torch_device).eval() - # load state dict copies weights but does not tie them - model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) - - torch.manual_seed(0) - tied_config = copy.deepcopy(config) - tied_config.tie_encoder_decoder = True - tied_model = model_class(config=tied_config).to(torch_device).eval() - - model_result = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - tied_model_result = tied_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) - random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() - - # check that outputs are equal - self.parent.assertTrue( - torch.allclose( - model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 - ) - ) - - # check that outputs after saving and loading are equal - with tempfile.TemporaryDirectory() as tmpdirname: - tied_model.save_pretrained(tmpdirname) - tied_model = model_class.from_pretrained(tmpdirname) - tied_model.to(torch_device) - tied_model.eval() - - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) - random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() - - tied_model_result = tied_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - # check that outputs are equal - self.parent.assertTrue( - torch.allclose( - model_result[0][0, :, random_slice_idx], - tied_model_result[0][0, :, random_slice_idx], - atol=1e-4, - ) - ) - - def check_resize_embeddings_timesfm_v1_1( - self, - config, - ): - prev_vocab_size = config.vocab_size - - config.tie_word_embeddings = False - model = TimesFMForPrediction(config=config).to(torch_device).eval() - model.resize_token_embeddings(prev_vocab_size - 10) - - self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) - self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) - self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() ( config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) = config_and_inputs + forecast_input, + frequency_input, + ) = self.prepare_config_and_inputs() inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": decoder_attention_mask, - "use_cache": False, + "inputs": forecast_input, + "freq": frequency_input, } return config, inputs_dict @require_torch -class TimesFMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = ( - (TimesFMModel, TimesFMForPrediction) - if is_torch_available() - else () - ) - all_generative_model_classes = (TimesFMForPrediction,) if is_torch_available() else () - all_parallelizable_model_classes = (TimesFMModel, TimesFMForPrediction) if is_torch_available() else () +class TimesFMModelTest( + ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase +): + all_model_classes = (TimesFMModel,) if is_torch_available() else () + all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () + all_parallelizable_model_classes = () fx_compatible = False test_pruning = False - test_resize_embeddings = True - test_model_parallel = True - is_encoder_decoder = True - # The small TimesFM model needs higher percentages for CPU/MP tests - model_split_percents = [0.5, 0.8, 0.9] + test_resize_embeddings = False + test_model_parallel = False + is_encoder_decoder = False def setUp(self): self.model_tester = TimesFMModelTester(self) - self.config_tester = ConfigTester(self, config_class=TimesFMConfig, d_model=37) + self.config_tester = ConfigTester(self, config_class=TimesFMConfig) - # TimesFMForSequenceClassification does not support inputs_embeds - def test_inputs_embeds(self): + def test_create_and_run_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in (TimesFMModel, TimesFMForPrediction): - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - - if not self.is_encoder_decoder: - input_ids = inputs["input_ids"] - del inputs["input_ids"] - else: - encoder_input_ids = inputs["input_ids"] - decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) - del inputs["input_ids"] - inputs.pop("decoder_input_ids", None) - - wte = model.get_input_embeddings() - if not self.is_encoder_decoder: - inputs["inputs_embeds"] = wte(input_ids) - else: - inputs["inputs_embeds"] = wte(encoder_input_ids) - inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) - - with torch.no_grad(): - model(**inputs)[0] - - def test_config_and_model_silu_gated(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - config = config_and_inputs[0] - config.feed_forward_proj = "gated-silu" - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_with_lm_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_with_lm_head(*config_and_inputs) - - def test_decoder_model_past(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) - - def test_decoder_model_past_with_attn_mask(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) - - def test_decoder_model_past_with_3d_attn_mask(self): - ( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) = self.model_tester.prepare_config_and_inputs() - - attention_mask = ids_tensor( - [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], - vocab_size=2, - ) - decoder_attention_mask = ids_tensor( - [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], - vocab_size=2, - ) - - self.model_tester.create_and_check_decoder_model_attention_mask_past( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) - - def test_decoder_model_past_with_large_inputs(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - - def test_generate_with_past_key_values(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) - - def test_encoder_decoder_shared_weights(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) - - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_model_fp16_forward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) - - def test_v1_1_resize_embeddings(self): - config = self.model_tester.prepare_config_and_inputs()[0] - self.model_tester.check_resize_embeddings_timesfm_v1_1(config) - - @slow - def test_model_from_pretrained(self): - model_name = "google/timesfm-1.0-200m" - model = TimesFMModel.from_pretrained(model_name) - self.assertIsNotNone(model) - - @unittest.skip(reason="Test has a segmentation fault on torch 1.8.0") - def test_export_to_onnx(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - model = TimesFMModel(config_and_inputs[0]).to(torch_device) - with tempfile.TemporaryDirectory() as tmpdirname: - torch.onnx.export( - model, - (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), - f"{tmpdirname}/timesfm_test.onnx", - export_params=True, - opset_version=9, - input_names=["input_ids", "decoder_input_ids"], - ) - - def test_generate_with_head_masking(self): - attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - config = config_and_inputs[0] - max_length = config_and_inputs[1].shape[-1] + 3 - model = TimesFMForPrediction(config).eval() + model = TimesFMModel(config) model.to(torch_device) - - head_masking = { - "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device), - "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), - "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), - } - - for attn_name, (name, mask) in zip(attention_names, head_masking.items()): - head_masks = {name: mask} - # Explicitly pass decoder_head_mask as it is required from TimesFM model when head_mask specified - if name == "head_mask": - head_masks["decoder_head_mask"] = torch.ones( - config.num_decoder_layers, config.num_heads, device=torch_device - ) - - out = model.generate( - config_and_inputs[1], - num_beams=1, - max_length=max_length, - output_attentions=True, - return_dict_in_generate=True, - **head_masks, - ) - # We check the state of decoder_attentions and cross_attentions just from the last step - attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] - self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) - - -class TimesFMEncoderOnlyModelTester: - def __init__( - self, - parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - # For common tests - use_attention_mask=True, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - is_training=False, - dropout_rate=0.1, - initializer_factor=0.002, - is_encoder_decoder=False, - eos_token_id=1, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - # For common tests - self.seq_length = self.encoder_seq_length - self.use_attention_mask = use_attention_mask - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets - self.dropout_rate = dropout_rate - self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.is_encoder_decoder = is_encoder_decoder - self.scope = None - self.is_training = is_training - - def get_large_model_config(self): - return TimesFMConfig.from_pretrained("google-timesfm/timesfm-base") - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) - - attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - - config = TimesFMConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - return ( - config, - input_ids, - attention_mask, - ) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - attention_mask, - ) = config_and_inputs - - inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - return config, inputs_dict - - -def use_task_specific_params(model, task): - model.config.update(model.config.task_specific_params[task]) - - -@require_torch -@require_accelerate -@require_tokenizers -@slow -class TimesFMModelFp16Tests(unittest.TestCase): - def test_fp16_fp32_conversion(self): - r""" - A test to check whether the argument `keep_in_fp32_modules` correctly does its job - """ - orig_import = __import__ - accelerate_mock = unittest.mock.Mock() - - # mock import of accelerate - def import_accelerate_mock(name, *args, **kwargs): - if name == "accelerate": - if accelerate_available: - return accelerate_mock - else: - raise ImportError - return orig_import(name, *args, **kwargs) - - # Load without using `accelerate` - with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): - accelerate_available = False - - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.float16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) - - # Load without in bf16 - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m", torch_dtype=torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) - - # Load using `accelerate` in bf16 - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, device_map="auto" - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) - - # Load using `accelerate` in bf16 - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) - - # Load without using `accelerate` - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.float16, low_cpu_mem_usage=True - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) - - # Load using `accelerate` - model = TimesFMForPrediction.from_pretrained( - "google/timesfm-1.0-200m", torch_dtype=torch.float16, device_map="auto" - ) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) - - -@require_torch -@require_sentencepiece -@require_tokenizers -class TimesFMModelIntegrationTests(unittest.TestCase): - @cached_property - def model(self): - return TimesFMForPrediction.from_pretrained("google-timesfm/timesfm-base").to(torch_device) - - @cached_property - def tokenizer(self): - return T5Tokenizer.from_pretrained("google-timesfm/timesfm-base") - - @slow - def test_torch_quant(self): - r""" - Test that a simple `torch.quantization.quantize_dynamic` call works on a TimesFM model. - """ - model_name = "google/flan-timesfm-small" - tokenizer = T5Tokenizer.from_pretrained(model_name) - model = TimesFMForPrediction.from_pretrained(model_name) - model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) - input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" - input_ids = tokenizer(input_text, return_tensors="pt").input_ids - _ = model.generate(input_ids) - - @slow - def test_small_generation(self): - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) - model.config.max_length = 8 - model.config.num_beams = 1 - model.config.do_sample = False - tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") - - input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device) - - sequences = model.generate(input_ids) - - output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] - self.assertTrue(output_str == "Hello there!") - - @slow - def test_small_integration_test(self): - """ - For comparision run: - >>> import timesfm # pip install timesfm==0.7.1 - >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_timesfm_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TimesFMForPrediction.from_pretrained("google/timesfm-1.0-200m").to(torch_device) - tokenizer = T5Tokenizer.from_pretrained("google/timesfm-1.0-200m") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -19.0845 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_v1_1_integration_test(self): - """ - For comparision run: - >>> import timesfm # pip install timesfm==0.7.1 - >>> from timesfm.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_timesfm_v1_1_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_mtf_small_timesfm_v1_1_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TimesFMForPrediction.from_pretrained("google/timesfm-v1_1-small").to(torch_device) - tokenizer = T5Tokenizer.from_pretrained("google/timesfm-v1_1-small") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -59.0293 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_bytimesfm_integration_test(self): - """ - For comparision run: - >>> import timesfm # pip install timesfm==0.9.1 - - >>> path_to_bytimesfm_small_checkpoint = '' - >>> timesfm_model = timesfm.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) - >>> vocab = timesfm.data.ByteVocabulary() - >>> score = timesfm_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TimesFMForPrediction.from_pretrained("google/bytimesfm-small").to(torch_device) - tokenizer = ByT5Tokenizer.from_pretrained("google/bytimesfm-small") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -60.7397 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_summarization(self): - model = self.model - tok = self.tokenizer - - FRANCE_ARTICLE = ( # @noqa - "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" - " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." - ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' - ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' - " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" - " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" - " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" - " phone at the wreckage site. The two publications described the supposed video, but did not post it on" - " their websites. The publications said that they watched the video, which was found by a source close to" - " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." - ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' - " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" - ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' - " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" - " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" - " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" - ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' - ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' - " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" - " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" - " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" - ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' - ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' - ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' - ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' - " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" - ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' - " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" - " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" - ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' - ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' - " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" - " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" - " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" - " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" - ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' - " sharing the information and documents -- including training and medical records -- with public" - " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" - " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" - " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" - " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" - " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." - " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" - " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." - " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." - " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" - " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" - " the flight school during his training were among several developments as investigators continued to" - " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" - " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" - ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' - " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" - " some point before his aviation career and underwent psychotherapy before he got his pilot's license." - " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" - " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" - " lose his pilot's license, a European government official briefed on the investigation told CNN on" - ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' - " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" - " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" - " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" - " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" - " he had psychological issues, the European government official said. But no matter what details emerge" - " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" - ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' - " that maybe they weren't going to keep doing their job and they're upset about that and so they're" - ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' - " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" - ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' - " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" - " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" - " Amiel and Anna-Maja Rappard contributed to this report." - ) - SHORTER_ARTICLE = ( - "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" - " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" - " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." - " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" - ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' - ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' - " situation in Palestinian territories, paving the way for possible war crimes investigations against" - " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" - " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" - " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" - ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' - ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' - ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' - " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" - ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' - " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." - ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' - ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' - " immediately end their pressure, and countries that support universal acceptance of the court's treaty" - ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' - " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" - ' decision to join a treaty to which over 100 countries around the world are members." In January, when' - " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" - ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' - " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" - ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' - ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' - ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' - " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" - ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' - " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" - ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' - " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" - " will include alleged war crimes committed since June. The International Criminal Court was set up in" - " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" - " and Faith Karimi contributed to this report." - ) - IRAN_ARTICLE = ( - "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" - " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" - " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." - " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" - " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" - " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" - " the announcement of the new framework will likely result in more heat than light. It will not be helped" - " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." - " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" - " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" - " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" - " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" - " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" - " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" - " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" - " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" - " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" - " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" - " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" - " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" - " point, and we'll know even more about Iran's program in the coming months and years because of the deal." - " In fact, the inspections provisions that are part of this agreement are designed to protect against any" - " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" - " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" - " warning that a deal might be killed by Congress or a future president). This of course is not the case." - " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," - " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" - " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" - " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" - " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" - " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" - " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" - " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" - " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" - " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" - " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" - " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" - ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' - " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" - " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" - " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" - " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" - " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" - " some insist that any agreement must address Iranian missile programs, human rights violations or support" - " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" - " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" - " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" - " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" - " fact-based, not based on questionable assertions or dubious assumptions." - ) - ARTICLE_SUBWAY = ( - "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - - expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' - " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" - " magazine says .", - "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" - " preliminary examination into the situation in the occupied Palestinian territory . as members of the" - " court, Palestinians may be subject to counter-charges as well .", - "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" - " the debate that has already begun since the announcement of the new framework will likely result in more" - " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" - " implement a rigorous inspection regime .", - "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" - ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' - " times, with nine of her marriages occurring between 1999 and 2002 .", - ] - - use_task_specific_params(model, "summarization") - - dct = tok( - [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], - padding="max_length", - truncation=True, - return_tensors="pt", - ).to(torch_device) - self.assertEqual(512, dct["input_ids"].shape[1]) - - hypotheses_batch = model.generate( - **dct, - num_beams=4, - length_penalty=2.0, - max_length=142, - min_length=56, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - - decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertListEqual( - expected_summaries, - decoded, - ) - - @slow - def test_translation_en_to_de(self): - model = self.model - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_de") - - en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' - expected_translation = ( - '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' - ) - - input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") - input_ids = input_ids.to(torch_device) - output = model.generate(input_ids) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(translation, expected_translation) - - @slow - def test_translation_en_to_fr(self): - model = self.model # google-timesfm/timesfm-base - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_fr") - - en_text = ( - ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' - " countless generations of stars: the oldest stars are seen as blue dots. " - ) - - input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") - input_ids = input_ids.to(torch_device) - - output = model.generate( - input_ids=input_ids, - num_beams=4, - length_penalty=2.0, - max_length=100, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - new_truncated_translation = ( - "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " - "un " - "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " - "sous forme " - "de points bleus." - ) - - self.assertEqual(translation, new_truncated_translation) - - @slow - def test_translation_en_to_ro(self): - model = self.model - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_ro") - en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022." - expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." - - inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device) - output = model.generate(**inputs) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(translation, expected_translation) - - @slow - def test_contrastive_search_timesfm(self): - article = ( - " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - article = "summarize: " + article.strip() - timesfm_tokenizer = AutoTokenizer.from_pretrained("flax-community/timesfm-base-cnn-dm") - timesfm_model = TimesFMForPrediction.from_pretrained("flax-community/timesfm-base-cnn-dm").to(torch_device) - input_ids = timesfm_tokenizer( - article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" - ).input_ids.to(torch_device) - - outputs = timesfm_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) - generated_text = timesfm_tokenizer.batch_decode(outputs, skip_special_tokens=True) - - self.assertListEqual( - generated_text, - [ - "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " - "permanent residence after the marriages, prosecutors say." - ], - ) - - -@require_torch -class TestAsymmetricTimesFM(unittest.TestCase): - def build_model_and_check_forward_pass(self, **kwargs): - tester = TimesFMModelTester(self, **kwargs) - config, *inputs = tester.prepare_config_and_inputs() - ( - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ) = inputs - model = TimesFMForPrediction(config=config).to(torch_device).eval() - outputs = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - labels=lm_labels, - ) - # outputs = model(*inputs) - assert len(outputs) == 4 - assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size) - assert outputs["loss"].size() == () - return model - - def test_small_decoder(self): - # num_hidden_layers is passed to TimesFMConfig as num_layers - model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2) - assert len(model.encoder.block) == 2 - assert len(model.decoder.block) == 1 - - def test_defaulting_to_symmetry(self): - # num_hidden_layers is passed to TimesFMConfig as num_layers - model = self.build_model_and_check_forward_pass(num_hidden_layers=2) - assert len(model.decoder.block) == len(model.encoder.block) == 2 + model.eval() + results = model.run_model(**inputs_dict) + assert results From 72ffaafa15bd99c564988fbbc842a1105724441b Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 10 Oct 2024 17:07:48 -0700 Subject: [PATCH 099/242] the model runs --- src/transformers/models/auto/modeling_auto.py | 4 +-- src/transformers/models/timesfm/__init__.py | 2 +- .../models/timesfm/configuration_timesfm.py | 2 -- .../models/timesfm/modeling_timesfm.py | 30 ++++++++----------- .../models/timesfm/timesfm_layers.py | 3 +- tests/models/timesfm/test_modeling_timesfm.py | 3 -- 6 files changed, 16 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b5371967eb02..d6fef140ed71 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -380,7 +380,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForPrediction"), + ("timesfm", "TimesFMModel"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -475,7 +475,6 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForPrediction"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -963,7 +962,6 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMForPrediction"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 91f4693ae2e5..82bbb6be22ce 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -29,7 +29,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["timesfm"] = [ + _import_structure["modeling_timesfm"] = [ "TimesFMModel", "TimesFMPreTrainedModel", ] diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 29948593aff5..69397214c7ce 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -101,7 +101,6 @@ def __init__( use_positional_embedding: bool = True, per_core_batch_size: int = 32, initializer_factor: float = 1.0, - backend: str = "gpu", **kwargs, ): self.patch_len = patch_len @@ -120,7 +119,6 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.backend = backend super().__init__( **kwargs, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 612493e7f7dd..ba1c2da9c1c8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -293,7 +293,6 @@ def __init__(self, config: TimesFMConfig): self.output_patch_len = config.horizon_len self.num_layers = config.num_layers self.model_dims = config.model_dim - self.backend = config.backend self.quantiles = config.quantiles self.num_heads = config.num_heads @@ -301,7 +300,6 @@ def __init__(self, config: TimesFMConfig): self.per_core_batch_size = config.per_core_batch_size self.global_batch_size = config.per_core_batch_size * self.num_cores self._horizon_start = self.context_len - self.input_patch_len - self._device = config.backend def _preprocess( self, inputs: Sequence[np.array], freq: Sequence[int] @@ -429,7 +427,7 @@ def forward( ], dtype=np.float32, ) - ).to(self._device) + ) input_padding_in = torch.from_numpy( np.array( input_padding[ @@ -439,22 +437,18 @@ def forward( ], dtype=np.float32, ) - ).to(self._device) - inp_freq_in = ( - torch.from_numpy( - np.array( - inp_freq[ - i - * self.global_batch_size : (i + 1) - * self.global_batch_size, - :, - ], - dtype=np.int32, - ) - ) - .long() - .to(self._device) ) + inp_freq_in = torch.from_numpy( + np.array( + inp_freq[ + i + * self.global_batch_size : (i + 1) + * self.global_batch_size, + :, + ], + dtype=np.int32, + ) + ).long() mean_output, full_output = self.decoder.decode( input_ts=input_ts_in, paddings=input_padding_in, diff --git a/src/transformers/models/timesfm/timesfm_layers.py b/src/transformers/models/timesfm/timesfm_layers.py index 713623eb98a7..0ba6f2c6f54d 100644 --- a/src/transformers/models/timesfm/timesfm_layers.py +++ b/src/transformers/models/timesfm/timesfm_layers.py @@ -56,8 +56,9 @@ def _get_patch_index(arr: torch.Tensor): mask = 1 - pad # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) num_valid_elements = torch.where( - torch.sum(mask, dim=1) == 0, + num_valid_elements == 0, torch.tensor( 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device ), diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 2aafe4133c8d..ebf5df4d2cf3 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -68,7 +68,6 @@ def __init__( use_positional_embedding: bool = True, per_core_batch_size: int = 32, initializer_factor: float = 1.0, - backend: str = "gpu", ): self.parent = parent self.patch_len = patch_len @@ -87,7 +86,6 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor - self.backend = backend def get_large_model_config(self): return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") @@ -110,7 +108,6 @@ def get_config(self): use_positional_embedding=self.use_positional_embedding, per_core_batch_size=self.per_core_batch_size, initializer_factor=self.initializer_factor, - backend=self.backend, ) def get_pipeline_config(self): From 3818ee44065751e36df71203db5be23c8f354603 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 24 Oct 2024 09:50:18 -0700 Subject: [PATCH 100/242] fixing unit tests --- .../models/timesfm/configuration_timesfm.py | 23 ++-- .../models/timesfm/modeling_timesfm.py | 101 ++++++++++++-- .../models/timesfm/timesfm_layers.py | 20 ++- tests/models/timesfm/test_modeling_timesfm.py | 124 +++++++++++++++--- 4 files changed, 224 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 69397214c7ce..6e6dc8aec307 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -37,30 +37,28 @@ class TimesFMConfig(PretrainedConfig): Arguments: patch_len (`int`, *optional*, defaults to 32): The length of one patch in the input sequence. - horizon_len (`int`, *optional*, defaults to 128): - The length of the prediction horizon. context_len (`int`, *optional*, defaults to 512): The length of the input context. + horizon_len (`int`, *optional*, defaults to 128): + The length of the prediction horizon. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. + num_layers (`int`, *optional*, defaults to 20): + Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * head_dim`. - num_layers (`int`, *optional*, defaults to 20): - Number of Transformer layers. num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. - tolerance (`float`, *optional*, defaults to 1e-6): - The tolerance for the quantile loss. dropout_rate (`float`, *optional*, defaults to 0.1): The ratio for all dropout layers. - classifier_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for classifier. - rms_norm_eps (`float`, *optional*, defaults to 1e-6): + tolerance (`float`, *optional*, defaults to 1e-06): + The tolerance for the quantile loss. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the RMS normalization layers. - quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.25, 0.5, 0.75, 0.9]`): + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`): The quantiles to predict. pad_val (`float`, *optional*, defaults to 1123581321.0): The value used to pad the predictions. @@ -68,11 +66,9 @@ class TimesFMConfig(PretrainedConfig): Whether to add positional embeddings. per_core_batch_size (`int`, *optional*, defaults to 32): The batch size per core for data parallelism. - initializer_factor (`float`, *optional*, defaults to 1): + initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - backend (`str`, *optional*, defaults to `"gpu"`): - The backend to use for the model. Can be either `"gpu"` or `"cpu"`. """ model_type = "timesfm" @@ -82,6 +78,7 @@ class TimesFMConfig(PretrainedConfig): "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } + is_encoder_decoder = False def __init__( self, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ba1c2da9c1c8..8757647075f6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -23,21 +23,31 @@ import logging +from dataclasses import dataclass from typing import Any, Sequence -import pandas as pd + import numpy as np import torch import torch.nn as nn + +from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from .configuration_timesfm import TimesFMConfig from .timesfm_layers import * +@dataclass +class TimesFMOutput(BaseModelOutput): + mean_predictions: np.ndarray = None + full_predictions: np.ndarray = None + + class TimesFMPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" config_class = TimesFMConfig base_model_prefix = "timesfm" + main_input_name = "inputs" def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -153,7 +163,9 @@ def _preprocess_input( # B x N x D patched_inputs = patched_inputs * (1.0 - patched_pads) + print(">>> PatchedDecoder patched_inputs", patched_inputs.shape) concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + print(">>> PatchedDecoder concat_inputs", concat_inputs.shape) model_input = self.input_ff_layer(concat_inputs) # A patch should not be padded even if there is at least one zero. @@ -190,7 +202,10 @@ def forward( input_ts: torch.Tensor, input_padding: torch.LongTensor, freq: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, ) -> torch.Tensor: + print(">>> PatchedDecoder input_ts", input_ts.shape) num_outputs = len(self.config.quantiles) + 1 model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, @@ -198,10 +213,14 @@ def forward( ) f_emb = self.freq_emb(freq) # B x 1 x D model_input += f_emb - model_output = self.stacked_transformer(model_input, patched_padding) + + print(">>> PatchedDecoder model_input", model_input.shape) + model_output, all_attentions, all_hidden_states = self.stacked_transformer(model_input, patched_padding, output_attentions=output_attentions, output_hidden_states=output_hidden_states) + if output_hidden_states: + all_hidden_states = [model_input] + all_hidden_states output_ts = self._postprocess_output(model_output, num_outputs, stats) - return output_ts + return output_ts, all_attentions, all_hidden_states def decode( self, @@ -212,7 +231,9 @@ def decode( output_patch_len: int | None = None, max_len: int = 512, return_forecast_on_context: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + output_attentions: bool = False, + output_hidden_states: bool = False, + ): """Auto-regressive decoding without caching. Args: @@ -249,7 +270,7 @@ def decode( current_padding = paddings[:, 0 : final_out.shape[1]] input_ts = final_out[:, -max_len:] input_padding = current_padding[:, -max_len:] - fprop_outputs = self(input_ts, input_padding, freq) + fprop_outputs, all_attentions, all_hidden_states = self.forward(input_ts, input_padding, freq, output_attentions=output_attentions, output_hidden_states=output_hidden_states) if return_forecast_on_context and step_index == 0: # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. @@ -276,7 +297,7 @@ def decode( # `full_outputs` indexing starts at the forecast horizon. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - return (full_outputs[:, :, 0], full_outputs) + return full_outputs[:, :, 0], full_outputs, all_attentions, all_hidden_states class TimesFMModel(TimesFMPreTrainedModel): @@ -321,7 +342,7 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - + print(">>> TimesFMModel _preprocess", len(inputs), inputs[0].shape) input_ts, input_padding, inp_freq = [], [], [] pmap_pad = ( @@ -353,6 +374,8 @@ def _preprocess( input_padding.append(input_padding[-1]) inp_freq.append(inp_freq[-1]) + print(">>> TimesFMModel input_ts", len(input_ts), input_ts[0].shape) + return ( np.stack(input_ts, axis=0), np.stack(input_padding, axis=0), @@ -368,6 +391,9 @@ def forward( forecast_context_len: int | None = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Forecasts on a list of time series. @@ -394,12 +420,15 @@ def forward( Raises: ValueError: If the checkpoint is not properly loaded. """ + if return_dict is None: + return_dict = self.config.use_return_dict if forecast_context_len is None: fcontext_len = self.context_len else: fcontext_len = forecast_context_len inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + print(">>> TimesFMModel forward", len(inputs), inputs[0].shape) inp_min = np.min([np.min(ts) for ts in inputs]) if window_size is not None: @@ -412,10 +441,18 @@ def forward( logging.info("No frequency provided via `freq`. Default to high (0).") freq = [0] * len(inputs) + if output_attentions is None: + output_attentions = self.config.output_attentions + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + print(">>> TimesFMModel input_ts", input_ts.shape) with torch.no_grad(): mean_outputs = [] full_outputs = [] + all_attentions = [] + all_hidden_states = [] assert input_ts.shape[0] % self.global_batch_size == 0 for i in range(input_ts.shape[0] // self.global_batch_size): input_ts_in = torch.from_numpy( @@ -449,12 +486,14 @@ def forward( dtype=np.int32, ) ).long() - mean_output, full_output = self.decoder.decode( + mean_output, full_output, attentions, hidden_states = self.decoder.decode( input_ts=input_ts_in, paddings=input_padding_in, freq=inp_freq_in, horizon_len=self.horizon_len, return_forecast_on_context=return_forecast_on_context, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, ) mean_output = mean_output.detach().cpu().numpy() full_output = full_output.detach().cpu().numpy() @@ -463,9 +502,36 @@ def forward( mean_outputs.append(mean_output) full_outputs.append(full_output) + if output_attentions: + if not all_attentions: + all_attentions = [[] for _ in range(len(attentions))] + for j in range(len(attentions)): + attentions[j] = attentions[j].detach().cpu().numpy() + attentions[j] = np.array(attentions[j]) + all_attentions[j].append(attentions[j]) + if output_hidden_states: + if not all_hidden_states: + all_hidden_states = [[] for _ in range(len(hidden_states))] + for j in range(len(hidden_states)): + hidden_states[j] = hidden_states[j].detach().cpu().numpy() + hidden_states[j] = np.array(hidden_states[j]) + all_hidden_states[j].append(hidden_states[j]) + mean_outputs = np.concatenate(mean_outputs, axis=0) full_outputs = np.concatenate(full_outputs, axis=0) + if output_attentions: + for j in range(len(all_attentions)): + all_attentions[j] = np.concatenate(all_attentions[j], axis=0) + if output_hidden_states: + for j in range(len(all_hidden_states)): + all_hidden_states[j] = np.concatenate(all_hidden_states[j], axis=0) + + if output_attentions: + print(">> TimesFMModel attentions", len(attentions), attentions[0].shape) + if output_hidden_states: + print(">> TimesFMModel hidden_states", len(hidden_states), hidden_states[0].shape) + if pmap_pad > 0: mean_outputs = mean_outputs[:-pmap_pad, ...] full_outputs = full_outputs[:-pmap_pad, ...] @@ -476,4 +542,21 @@ def forward( if inp_min >= 0 and truncate_negative: mean_outputs = np.maximum(mean_outputs, 0.0) full_outputs = np.maximum(full_outputs, 0.0) - return mean_outputs, full_outputs + + if return_dict: + result = TimesFMOutput() + result.mean_predictions = mean_outputs + result.full_predictions = full_outputs + if output_attentions: + result.attentions = all_attentions + if output_hidden_states: + result.hidden_states = all_hidden_states + + return result + else: + return_tuple = [mean_outputs, full_outputs] + if output_attentions: + return_tuple.append(all_attentions) + if output_hidden_states: + return_tuple.append(all_hidden_states) + return tuple(return_tuple) diff --git a/src/transformers/models/timesfm/timesfm_layers.py b/src/transformers/models/timesfm/timesfm_layers.py index 0ba6f2c6f54d..91fd460a120d 100644 --- a/src/transformers/models/timesfm/timesfm_layers.py +++ b/src/transformers/models/timesfm/timesfm_layers.py @@ -17,10 +17,11 @@ import math from typing import List, Tuple + import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn def masked_mean_std( @@ -379,6 +380,7 @@ def forward( hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 + print(">>> TimesFMAttention hidden_states_shape", hidden_states_shape) batch_size, input_len, _ = hidden_states_shape qkv = self.qkv_proj(hidden_states) @@ -461,6 +463,7 @@ def forward( kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: # Self Attention + print(">>> TimesFMDecoderLayer hidden_states", hidden_states.shape) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) scores, hidden_states = self.self_attn( @@ -511,21 +514,32 @@ def forward( paddings: torch.Tensor, kv_write_indices: torch.Tensor | None = None, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, ) -> torch.Tensor: + print(">>> StackedDecoder hidden_states", hidden_states.shape) padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) atten_mask = causal_mask(hidden_states) mask = merge_masks(padding_mask, atten_mask) + all_attentions = [] + all_hidden_states = [] + for i in range(len(self.layers)): layer = self.layers[i] kv_cache = kv_caches[i] if kv_caches is not None else None - _, hidden_states = layer( + scores, hidden_states = layer( hidden_states=hidden_states, mask=mask, paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, ) - return hidden_states + if output_attentions: + all_attentions.append(scores) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return hidden_states, all_attentions, all_hidden_states class PositionalEmbedding(torch.nn.Module): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index ebf5df4d2cf3..f22042ec6c71 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -13,37 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import numpy as np +import inspect import unittest from typing import List +import numpy as np +import torch + from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( - require_accelerate, - require_sentencepiece, - require_tokenizers, require_torch, - slow, torch_device, ) -from transformers.utils import cached_property, is_torch_fx_available +from transformers.utils import is_torch_fx_available -from ...generation.test_utils import GenerationTesterMixin +# from ...generation.test_utils import GenerationTesterMixin +# define our own GenerationTesters from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin -from ...test_pipeline_mixin import PipelineTesterMixin + + +# from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace + pass if is_torch_available(): - import torch from transformers import ( - AutoTokenizer, TimesFMModel, ) @@ -68,6 +67,7 @@ def __init__( use_positional_embedding: bool = True, per_core_batch_size: int = 32, initializer_factor: float = 1.0, + is_training: bool = False, ): self.parent = parent self.patch_len = patch_len @@ -78,14 +78,15 @@ def __init__( self.freq_size = freq_size self.model_dim = model_dim self.head_dim = head_dim - self.num_layers = num_layers - self.num_heads = num_heads + self.num_hidden_layers = num_layers + self.num_attention_heads = num_heads self.dropout_rate = dropout_rate self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding self.per_core_batch_size = per_core_batch_size self.initializer_factor = initializer_factor + self.is_training = is_training def get_large_model_config(self): return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") @@ -100,8 +101,8 @@ def get_config(self): freq_size=self.freq_size, model_dim=self.model_dim, head_dim=self.head_dim, - num_layers=self.num_layers, - num_heads=self.num_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, dropout_rate=self.dropout_rate, tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, @@ -145,7 +146,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFMModelTest( - ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase + ModelTesterMixin, unittest.TestCase ): all_model_classes = (TimesFMModel,) if is_torch_available() else () all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () @@ -155,6 +156,7 @@ class TimesFMModelTest( test_resize_embeddings = False test_model_parallel = False is_encoder_decoder = False + test_inputs_embeds = False def setUp(self): self.model_tester = TimesFMModelTester(self) @@ -165,5 +167,89 @@ def test_create_and_run_model(self): model = TimesFMModel(config) model.to(torch_device) model.eval() - results = model.run_model(**inputs_dict) - assert results + results = model(**inputs_dict) + assert results.mean_predictions is not None + + def test_attention_outputs(self): + if not self.has_attentions: + self.skipTest(reason="Model does not output attentions") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + @unittest.skip(reason="Model does not have input embeddings") + def test_model_get_set_embeddings(self): + pass + + # the main input name is `inputs` + def test_model_main_input_name(self): + model_signature = inspect.signature(getattr(TimesFMModel, "forward")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1] + self.assertEqual(TimesFMModel.main_input_name, observed_main_input_name) From 001365597c1e089c41e2a7c6f4bbc9c15945174d Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 6 Nov 2024 15:29:13 -0800 Subject: [PATCH 101/242] fixing unit tests in progress --- .../models/timesfm/configuration_timesfm.py | 9 +++--- .../models/timesfm/modeling_timesfm.py | 32 +++++++++---------- tests/models/timesfm/test_modeling_timesfm.py | 24 +++++++++----- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 6e6dc8aec307..0ff463ba270d 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -64,8 +64,8 @@ class TimesFMConfig(PretrainedConfig): The value used to pad the predictions. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. - per_core_batch_size (`int`, *optional*, defaults to 32): - The batch size per core for data parallelism. + batch_size (`int`, *optional*, defaults to 32): + The batch size. initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). @@ -96,7 +96,7 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - per_core_batch_size: int = 32, + batch_size: int = 32, initializer_factor: float = 1.0, **kwargs, ): @@ -114,10 +114,11 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.per_core_batch_size = per_core_batch_size + self.batch_size = batch_size self.initializer_factor = initializer_factor super().__init__( + is_encoder_decoder=self.is_encoder_decoder, **kwargs, ) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8757647075f6..f9f0a9a8ce30 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -50,11 +50,14 @@ class TimesFMPreTrainedModel(PreTrainedModel): main_input_name = "inputs" def _init_weights(self, module): + print(">>> TimesFMPreTrainedModel _init_weights") if isinstance(module, nn.Embedding): - nn.init.uniform_(module.weight, a=-0.1, b=0.1) + print(">>> TimesFMPreTrainedModel Embedding std", self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_factor) elif isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight) + print(">>> TimesFMPreTrainedModel Linear std", self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.bias is not None: nn.init.zeros_(module.bias) @@ -316,10 +319,7 @@ def __init__(self, config: TimesFMConfig): self.model_dims = config.model_dim self.quantiles = config.quantiles self.num_heads = config.num_heads - - self.num_cores = 1 - self.per_core_batch_size = config.per_core_batch_size - self.global_batch_size = config.per_core_batch_size * self.num_cores + self.batch_size = config.batch_size self._horizon_start = self.context_len - self.input_patch_len def _preprocess( @@ -346,8 +346,8 @@ def _preprocess( input_ts, input_padding, inp_freq = [], [], [] pmap_pad = ( - (len(inputs) - 1) // self.global_batch_size + 1 - ) * self.global_batch_size - len(inputs) + (len(inputs) - 1) // self.batch_size + 1 + ) * self.batch_size - len(inputs) for i, ts in enumerate(inputs): input_len = ts.shape[0] @@ -453,14 +453,14 @@ def forward( full_outputs = [] all_attentions = [] all_hidden_states = [] - assert input_ts.shape[0] % self.global_batch_size == 0 - for i in range(input_ts.shape[0] // self.global_batch_size): + assert input_ts.shape[0] % self.batch_size == 0 + for i in range(input_ts.shape[0] // self.batch_size): input_ts_in = torch.from_numpy( np.array( input_ts[ i - * self.global_batch_size : (i + 1) - * self.global_batch_size + * self.batch_size : (i + 1) + * self.batch_size ], dtype=np.float32, ) @@ -469,8 +469,8 @@ def forward( np.array( input_padding[ i - * self.global_batch_size : (i + 1) - * self.global_batch_size + * self.batch_size : (i + 1) + * self.batch_size ], dtype=np.float32, ) @@ -479,8 +479,8 @@ def forward( np.array( inp_freq[ i - * self.global_batch_size : (i + 1) - * self.global_batch_size, + * self.batch_size : (i + 1) + * self.batch_size, :, ], dtype=np.int32, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index f22042ec6c71..645ad268448e 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -55,18 +55,18 @@ def __init__( context_len: int = 512, horizon_len: int = 128, freq_size: int = 3, - num_layers: int = 20, - model_dim: int = 1280, - head_dim: int = 80, - num_heads: int = 16, + num_layers: int = 4, + model_dim: int = 128, + head_dim: int = 16, + num_heads: int = 4, dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - per_core_batch_size: int = 32, - initializer_factor: float = 1.0, + batch_size: int = 32, + initializer_factor: float = 0.0, is_training: bool = False, ): self.parent = parent @@ -84,9 +84,13 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.per_core_batch_size = per_core_batch_size + self.batch_size = batch_size self.initializer_factor = initializer_factor self.is_training = is_training + + # The size of test input + self.seq_length = context_len // patch_len + self.hidden_size = model_dim def get_large_model_config(self): return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") @@ -107,7 +111,7 @@ def get_config(self): tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, use_positional_embedding=self.use_positional_embedding, - per_core_batch_size=self.per_core_batch_size, + batch_size=self.batch_size, initializer_factor=self.initializer_factor, ) @@ -247,6 +251,10 @@ def test_attention_outputs(self): def test_model_get_set_embeddings(self): pass + @unittest.skip(reason="Model does not have head mask") + def test_headmasking(self): + pass + # the main input name is `inputs` def test_model_main_input_name(self): model_signature = inspect.signature(getattr(TimesFMModel, "forward")) From 6419285d9ed187ffb4a480602ee79456a0c86950 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 7 Nov 2024 19:29:56 +0100 Subject: [PATCH 102/242] add post_init --- .../models/timesfm/modeling_timesfm.py | 16 +++++++++++++++- tests/models/timesfm/test_modeling_timesfm.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f9f0a9a8ce30..dc9d36908736 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -33,7 +33,15 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from .configuration_timesfm import TimesFMConfig -from .timesfm_layers import * +from .timesfm_layers import ( + PositionalEmbedding, + ResidualBlock, + RMSNorm, + StackedDecoder, + masked_mean_std, + moving_average, + shift_padded_seq, +) @dataclass @@ -105,6 +113,9 @@ def __init__(self, config: TimesFMConfig): self.position_emb = PositionalEmbedding( embedding_dims=self.config.model_dim, ) + + # Initialize weights and apply final processing + self.post_init() def _forward_transform( self, inputs: torch.Tensor, patched_pads: torch.Tensor @@ -322,6 +333,9 @@ def __init__(self, config: TimesFMConfig): self.batch_size = config.batch_size self._horizon_start = self.context_len - self.input_patch_len + # Initialize weights and apply final processing + self.post_init() + def _preprocess( self, inputs: Sequence[np.array], freq: Sequence[int] ) -> tuple[np.array, np.array, int]: diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 645ad268448e..8f7853398147 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -87,7 +87,7 @@ def __init__( self.batch_size = batch_size self.initializer_factor = initializer_factor self.is_training = is_training - + # The size of test input self.seq_length = context_len // patch_len self.hidden_size = model_dim From 7cd2e41b5860dee110ea7865b2e7b2ea0911602c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 7 Nov 2024 19:40:01 +0100 Subject: [PATCH 103/242] do not change TimesFMOutput --- .../models/timesfm/modeling_timesfm.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index dc9d36908736..e6ac77c6a418 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -558,15 +558,12 @@ def forward( full_outputs = np.maximum(full_outputs, 0.0) if return_dict: - result = TimesFMOutput() - result.mean_predictions = mean_outputs - result.full_predictions = full_outputs - if output_attentions: - result.attentions = all_attentions - if output_hidden_states: - result.hidden_states = all_hidden_states - - return result + return TimesFMOutput( + mean_predictions=mean_outputs, + full_predictions=full_outputs, + attentions=all_attentions if output_attentions else None, + hidden_states=all_hidden_states if output_hidden_states else None, + ) else: return_tuple = [mean_outputs, full_outputs] if output_attentions: From 47affe8668b834a5988911937ade622298dbab56 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 13 Nov 2024 17:11:00 -0800 Subject: [PATCH 104/242] fixing unit tests --- .../models/timesfm/modeling_timesfm.py | 147 ++++++++---------- tests/models/timesfm/test_modeling_timesfm.py | 73 --------- 2 files changed, 69 insertions(+), 151 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e6ac77c6a418..f34f0b64deb0 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -58,13 +58,10 @@ class TimesFMPreTrainedModel(PreTrainedModel): main_input_name = "inputs" def _init_weights(self, module): - print(">>> TimesFMPreTrainedModel _init_weights") if isinstance(module, nn.Embedding): - print(">>> TimesFMPreTrainedModel Embedding std", self.config.initializer_factor) module.weight.data.normal_(mean=0, std=self.config.initializer_factor) elif isinstance(module, nn.Linear): - print(">>> TimesFMPreTrainedModel Linear std", self.config.initializer_factor) module.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.bias is not None: nn.init.zeros_(module.bias) @@ -462,84 +459,77 @@ def forward( input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) print(">>> TimesFMModel input_ts", input_ts.shape) - with torch.no_grad(): - mean_outputs = [] - full_outputs = [] - all_attentions = [] - all_hidden_states = [] - assert input_ts.shape[0] % self.batch_size == 0 - for i in range(input_ts.shape[0] // self.batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + mean_outputs = [] + full_outputs = [] + all_attentions = [] + all_hidden_states = [] + assert input_ts.shape[0] % self.batch_size == 0 + for i in range(input_ts.shape[0] // self.batch_size): + input_ts_in = torch.from_numpy( + np.array( + input_ts[ + i + * self.batch_size : (i + 1) + * self.batch_size + ], + dtype=np.float32, ) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + ) + input_padding_in = torch.from_numpy( + np.array( + input_padding[ + i + * self.batch_size : (i + 1) + * self.batch_size + ], + dtype=np.float32, ) - inp_freq_in = torch.from_numpy( - np.array( - inp_freq[ - i - * self.batch_size : (i + 1) - * self.batch_size, - :, - ], - dtype=np.int32, - ) - ).long() - mean_output, full_output, attentions, hidden_states = self.decoder.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + ) + inp_freq_in = torch.from_numpy( + np.array( + inp_freq[ + i + * self.batch_size : (i + 1) + * self.batch_size, + :, + ], + dtype=np.int32, ) - mean_output = mean_output.detach().cpu().numpy() - full_output = full_output.detach().cpu().numpy() - mean_output = np.array(mean_output) - full_output = np.array(full_output) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - if output_attentions: - if not all_attentions: - all_attentions = [[] for _ in range(len(attentions))] - for j in range(len(attentions)): - attentions[j] = attentions[j].detach().cpu().numpy() - attentions[j] = np.array(attentions[j]) - all_attentions[j].append(attentions[j]) - if output_hidden_states: - if not all_hidden_states: - all_hidden_states = [[] for _ in range(len(hidden_states))] - for j in range(len(hidden_states)): - hidden_states[j] = hidden_states[j].detach().cpu().numpy() - hidden_states[j] = np.array(hidden_states[j]) - all_hidden_states[j].append(hidden_states[j]) - - mean_outputs = np.concatenate(mean_outputs, axis=0) - full_outputs = np.concatenate(full_outputs, axis=0) + ).long() + mean_output, full_output, attentions, hidden_states = self.decoder.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + if output_attentions: + if not all_attentions: + all_attentions = [[] for _ in range(len(attentions))] + for j in range(len(attentions)): + attentions[j] = attentions[j] + all_attentions[j].append(attentions[j]) + if output_hidden_states: + if not all_hidden_states: + all_hidden_states = [[] for _ in range(len(hidden_states))] + for j in range(len(hidden_states)): + hidden_states[j] = hidden_states[j] + all_hidden_states[j].append(hidden_states[j]) + + mean_outputs = torch.cat(mean_outputs, axis=0) + full_outputs = torch.cat(full_outputs, axis=0) if output_attentions: for j in range(len(all_attentions)): - all_attentions[j] = np.concatenate(all_attentions[j], axis=0) + all_attentions[j] = torch.cat(all_attentions[j], axis=0) if output_hidden_states: for j in range(len(all_hidden_states)): - all_hidden_states[j] = np.concatenate(all_hidden_states[j], axis=0) + all_hidden_states[j] = torch.cat(all_hidden_states[j], axis=0) if output_attentions: print(">> TimesFMModel attentions", len(attentions), attentions[0].shape) @@ -554,8 +544,8 @@ def forward( mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] if inp_min >= 0 and truncate_negative: - mean_outputs = np.maximum(mean_outputs, 0.0) - full_outputs = np.maximum(full_outputs, 0.0) + mean_outputs = torch.maximum(mean_outputs, 0.0) + full_outputs = torch.maximum(full_outputs, 0.0) if return_dict: return TimesFMOutput( @@ -565,9 +555,10 @@ def forward( hidden_states=all_hidden_states if output_hidden_states else None, ) else: - return_tuple = [mean_outputs, full_outputs] - if output_attentions: - return_tuple.append(all_attentions) + return_tuple = [] if output_hidden_states: return_tuple.append(all_hidden_states) + if output_attentions: + return_tuple.append(all_attentions) + return_tuple += [mean_outputs, full_outputs] return tuple(return_tuple) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 8f7853398147..c928cf9aca8e 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -174,79 +174,6 @@ def test_create_and_run_model(self): results = model(**inputs_dict) assert results.mean_predictions is not None - def test_attention_outputs(self): - if not self.has_attentions: - self.skipTest(reason="Model does not output attentions") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - - seq_len = getattr(self.model_tester, "seq_length", None) - decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) - encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) - chunk_length = getattr(self.model_tester, "chunk_length", None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - out_len = len(outputs) - - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - - self.assertEqual(out_len + added_hidden_states, len(outputs)) - - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - @unittest.skip(reason="Model does not have input embeddings") def test_model_get_set_embeddings(self): pass From bbf738cf333f5f81a0d2a6b68319a30a4fe5562b Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 14 Nov 2024 17:48:25 -0800 Subject: [PATCH 105/242] all unit tests passed --- .../models/timesfm/modeling_timesfm.py | 124 +++++------------- 1 file changed, 31 insertions(+), 93 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f34f0b64deb0..f33c169f2555 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -308,7 +308,7 @@ def decode( # `full_outputs` indexing starts at the forecast horizon. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - return full_outputs[:, :, 0], full_outputs, all_attentions, all_hidden_states + return full_outputs[:, :, 0], full_outputs, fprop_outputs, all_attentions, all_hidden_states class TimesFMModel(TimesFMPreTrainedModel): @@ -356,10 +356,6 @@ def _preprocess( print(">>> TimesFMModel _preprocess", len(inputs), inputs[0].shape) input_ts, input_padding, inp_freq = [], [], [] - pmap_pad = ( - (len(inputs) - 1) // self.batch_size + 1 - ) * self.batch_size - len(inputs) - for i, ts in enumerate(inputs): input_len = ts.shape[0] padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) @@ -379,19 +375,12 @@ def _preprocess( input_padding.append(padding) inp_freq.append(freq[i]) - # Padding the remainder batch. - for _ in range(pmap_pad): - input_ts.append(input_ts[-1]) - input_padding.append(input_padding[-1]) - inp_freq.append(inp_freq[-1]) - print(">>> TimesFMModel input_ts", len(input_ts), input_ts[0].shape) return ( np.stack(input_ts, axis=0), np.stack(input_padding, axis=0), np.array(inp_freq).astype(np.int32).reshape(-1, 1), - pmap_pad, ) def forward( @@ -457,88 +446,36 @@ def forward( if output_hidden_states is None: output_hidden_states = self.config.output_hidden_states - input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) print(">>> TimesFMModel input_ts", input_ts.shape) - mean_outputs = [] - full_outputs = [] - all_attentions = [] - all_hidden_states = [] - assert input_ts.shape[0] % self.batch_size == 0 - for i in range(input_ts.shape[0] // self.batch_size): - input_ts_in = torch.from_numpy( - np.array( - input_ts[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + + input_ts_in = torch.from_numpy( + np.array( + input_ts, + dtype=np.float32, ) - input_padding_in = torch.from_numpy( - np.array( - input_padding[ - i - * self.batch_size : (i + 1) - * self.batch_size - ], - dtype=np.float32, - ) + ) + input_padding_in = torch.from_numpy( + np.array( + input_padding, + dtype=np.float32, ) - inp_freq_in = torch.from_numpy( - np.array( - inp_freq[ - i - * self.batch_size : (i + 1) - * self.batch_size, - :, - ], - dtype=np.int32, - ) - ).long() - mean_output, full_output, attentions, hidden_states = self.decoder.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + ) + inp_freq_in = torch.from_numpy( + np.array( + inp_freq, + dtype=np.int32, ) - mean_outputs.append(mean_output) - full_outputs.append(full_output) - - if output_attentions: - if not all_attentions: - all_attentions = [[] for _ in range(len(attentions))] - for j in range(len(attentions)): - attentions[j] = attentions[j] - all_attentions[j].append(attentions[j]) - if output_hidden_states: - if not all_hidden_states: - all_hidden_states = [[] for _ in range(len(hidden_states))] - for j in range(len(hidden_states)): - hidden_states[j] = hidden_states[j] - all_hidden_states[j].append(hidden_states[j]) - - mean_outputs = torch.cat(mean_outputs, axis=0) - full_outputs = torch.cat(full_outputs, axis=0) - - if output_attentions: - for j in range(len(all_attentions)): - all_attentions[j] = torch.cat(all_attentions[j], axis=0) - if output_hidden_states: - for j in range(len(all_hidden_states)): - all_hidden_states[j] = torch.cat(all_hidden_states[j], axis=0) - - if output_attentions: - print(">> TimesFMModel attentions", len(attentions), attentions[0].shape) - if output_hidden_states: - print(">> TimesFMModel hidden_states", len(hidden_states), hidden_states[0].shape) - - if pmap_pad > 0: - mean_outputs = mean_outputs[:-pmap_pad, ...] - full_outputs = full_outputs[:-pmap_pad, ...] + ).long() + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decoder.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] @@ -549,13 +486,14 @@ def forward( if return_dict: return TimesFMOutput( - mean_predictions=mean_outputs, - full_predictions=full_outputs, + last_hidden_state=last_hidden_state, attentions=all_attentions if output_attentions else None, hidden_states=all_hidden_states if output_hidden_states else None, + mean_predictions=mean_outputs, + full_predictions=full_outputs, ) else: - return_tuple = [] + return_tuple = [last_hidden_state] if output_hidden_states: return_tuple.append(all_hidden_states) if output_attentions: From bb2a850cd97c1e2d028a06154566be0e3840509c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 24 Nov 2024 16:07:49 +0100 Subject: [PATCH 106/242] remove timesfm_layers --- src/transformers/models/timesfm/__init__.py | 4 +- .../models/timesfm/modeling_timesfm.py | 605 ++++++++++++++++-- .../models/timesfm/timesfm_layers.py | 597 ----------------- tests/models/timesfm/test_modeling_timesfm.py | 1 - 4 files changed, 550 insertions(+), 657 deletions(-) delete mode 100644 src/transformers/models/timesfm/timesfm_layers.py diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 82bbb6be22ce..6592a5b1620e 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -51,6 +51,4 @@ else: import sys - sys.modules[__name__] = _LazyModule( - __name__, globals()["__file__"], _import_structure, module_spec=__spec__ - ) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f33c169f2555..9e0959b8e12f 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -21,27 +21,19 @@ # - PreTrainedModel for the models (it-self a sub-class of nn.Module) #################################################### - import logging +import math from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any, List, Sequence, Tuple import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from .configuration_timesfm import TimesFMConfig -from .timesfm_layers import ( - PositionalEmbedding, - ResidualBlock, - RMSNorm, - StackedDecoder, - masked_mean_std, - moving_average, - shift_padded_seq, -) @dataclass @@ -50,6 +42,521 @@ class TimesFMOutput(BaseModelOutput): full_predictions: np.ndarray = None +class TimesFMTransformerMLP(nn.Module): + """Pax transformer MLP in pytorch.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFMResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super().__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.hidden_layer(x) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class TimesFMRMSNorm(torch.nn.Module): + """Pax rms norm in pytorch.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + if self.add_unit_offset: + output = output * (1 + self.weight.float()) + else: + output = output * self.weight.float() + return output.type_as(x) + + +class TimesFMPositionalEmbedding(nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__( + self, + embedding_dims: int, + min_timescale: int = 1, + max_timescale: int = 10_000, + ) -> None: + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + assert seq_length is not None + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + assert position.ndim == 2, position.shape + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max( + num_timescales - 1, 1 + ) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +class TimesFMAttention(nn.Module): + """Implements the attention used in TimesFM.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = nn.Parameter( + torch.empty((self.head_dim,), dtype=torch.float32), + ) + + self.qkv_proj = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: + # [batch_size, n_local_heads, input_len, head_dim] + r_softplus_0 = 1.442695041 + softplus_func = torch.nn.Softplus() + scale = r_softplus_0 / math.sqrt(self.head_dim) + scale = scale * softplus_func(self.scaling) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xq = self._per_dim_scaling(xq) + + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + else: + key = xk + value = xv + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + # return scores, output.transpose(1, 2).contiguous() + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return scores, output + + +class TimesFMDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.self_attn = TimesFMAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + ) + self.mlp = TimesFMTransformerMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + self.input_layernorm = TimesFMRMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + scores, hidden_states = self.self_attn( + hidden_states=hidden_states, + mask=mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class TimesFMStackedDecoder(nn.Module): + """Stacked transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + num_layers: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TimesFMDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + ) + ) + + def forward( + self, + hidden_states: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: torch.Tensor | None = None, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> torch.Tensor: + padding_mask = timesfm_convert_paddings_to_mask(paddings, hidden_states.dtype) + atten_mask = timesfm_causal_mask(hidden_states) + mask = timesfm_merge_masks(padding_mask, atten_mask) + all_attentions = [] + all_hidden_states = [] + + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = kv_caches[i] if kv_caches is not None else None + scores, hidden_states = layer( + hidden_states=hidden_states, + mask=mask, + paddings=paddings, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + if output_attentions: + all_attentions.append(scores) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return hidden_states, all_attentions, all_hidden_states + + +# Move utility functions here +def timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded values. + """ + + # Selecting the first patch with more than 3 unpadded values. + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + pad_sum = torch.sum(1 - padding, dim=2) + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.where( + num_valid_elements == 0, + torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device), + num_valid_elements, + ) + + # Calculate the masked sum and squared sum + masked_sum = torch.sum(arr * mask, dim=1) + masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) + + # Calculate the masked mean and standard deviation + masked_mean = masked_sum / num_valid_elements + masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 + masked_var = torch.where( + masked_var < 0.0, + torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), + masked_var, + ) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + +def timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + The shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = torch.arange(num_seq).to(seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, feature_dim) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + +def timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + return [smoothed_arr, arr - smoothed_arr] + + +def timesfm_get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def timesfm_causal_mask(input_t: torch.Tensor) -> torch.Tensor: + """Computes and returns causal mask. + + Args: + input_t: A torch.Tensor of shape [B, T, D]. + + Returns: + An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has + already been converted to large negative values. + """ + assert input_t.dtype.is_floating_point, input_t.dtype + large_negative_number = timesfm_get_large_negative_number(input_t.dtype) + t = input_t.shape[1] + col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) + row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) + mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number + return mask.unsqueeze(0).unsqueeze(0).to(input_t.device) # Equivalent to jnp.newaxis + + +def timesfm_convert_paddings_to_mask(paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Converts binary paddings to a logit mask ready to add to attention matrix. + + Args: + paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding + token. + dtype: data type of the input. + + Returns: + A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. + """ + attention_mask = paddings.detach().clone() + attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis + attention_mask *= timesfm_get_large_negative_number(dtype) + return attention_mask + + +def timesfm_merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Merges 2 masks. + + logscale mask is expected but 0/1 mask is also fine. + + Args: + a: torch.Tensor of shape [1|B, 1, 1|T, S]. + b: torch.Tensor of shape [1|B, 1, 1|T, S]. + + Returns: + torch.Tensor of shape [1|B, 1, 1|T, S]. + """ + + def expand_t(key_mask): + query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose + return torch.minimum(query_mask, key_mask) + + if a.shape[2] != b.shape[2]: + if a.shape[2] == 1: + a = expand_t(a) + else: + assert b.shape[2] == 1 + b = expand_t(b) + + assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." + return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum + + class TimesFMPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" @@ -70,10 +577,10 @@ def _init_weights(self, module): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) - elif isinstance(module, RMSNorm): + elif isinstance(module, TimesFMRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, PositionalEmbedding): + elif isinstance(module, TimesFMPositionalEmbedding): pass @@ -84,20 +591,18 @@ def __init__(self, config: TimesFMConfig): super().__init__(config) self.config = config - self.input_ff_layer = ResidualBlock( + self.input_ff_layer = TimesFMResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, hidden_dims=config.model_dim, ) - self.freq_emb = nn.Embedding( - num_embeddings=config.freq_size, embedding_dim=config.model_dim - ) - self.horizon_ff_layer = ResidualBlock( + self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) + self.horizon_ff_layer = TimesFMResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.model_dim, ) - self.stacked_transformer = StackedDecoder( + self.stacked_transformer = TimesFMStackedDecoder( hidden_size=self.config.model_dim, intermediate_size=self.config.model_dim, num_heads=self.config.num_heads, @@ -107,10 +612,10 @@ def __init__(self, config: TimesFMConfig): rms_norm_eps=self.config.rms_norm_eps, ) if self.config.use_positional_embedding: - self.position_emb = PositionalEmbedding( + self.position_emb = TimesFMPositionalEmbedding( embedding_dims=self.config.model_dim, ) - + # Initialize weights and apply final processing self.post_init() @@ -118,7 +623,7 @@ def _forward_transform( self, inputs: torch.Tensor, patched_pads: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Input is of shape [B, N, P].""" - mu, sigma = masked_mean_std(inputs, patched_pads) + mu, sigma = timesfm_masked_mean_std(inputs, patched_pads) sigma = torch.where( sigma < self.config.tolerance, torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), @@ -129,16 +634,12 @@ def _forward_transform( outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] outputs = torch.where( torch.abs(inputs - self.config.pad_val) < self.config.tolerance, - torch.tensor( - self.config.pad_val, dtype=outputs.dtype, device=outputs.device - ), + torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device), outputs, ) return outputs, (mu, sigma) - def _reverse_transform( - self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] - ) -> torch.Tensor: + def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: """Output is of shape [B, N, P, Q].""" mu, sigma = stats return outputs * sigma[:, None, None, None] + mu[:, None, None, None] @@ -174,19 +675,15 @@ def _preprocess_input( # B x N x D patched_inputs = patched_inputs * (1.0 - patched_pads) - print(">>> PatchedDecoder patched_inputs", patched_inputs.shape) concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) - print(">>> PatchedDecoder concat_inputs", concat_inputs.shape) model_input = self.input_ff_layer(concat_inputs) # A patch should not be padded even if there is at least one zero. - patched_padding = torch.min(patched_pads, dim=-1)[ - 0 - ] # Get the values from the min result + patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result if self.config.use_positional_embedding: pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) - pos_emb = shift_padded_seq(patched_padding, pos_emb) + pos_emb = timesfm_shift_padded_seq(patched_padding, pos_emb) model_input += pos_emb return model_input, patched_padding, stats, patched_inputs @@ -216,7 +713,6 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, ) -> torch.Tensor: - print(">>> PatchedDecoder input_ts", input_ts.shape) num_outputs = len(self.config.quantiles) + 1 model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, @@ -225,8 +721,12 @@ def forward( f_emb = self.freq_emb(freq) # B x 1 x D model_input += f_emb - print(">>> PatchedDecoder model_input", model_input.shape) - model_output, all_attentions, all_hidden_states = self.stacked_transformer(model_input, patched_padding, output_attentions=output_attentions, output_hidden_states=output_hidden_states) + model_output, all_attentions, all_hidden_states = self.stacked_transformer( + model_input, + patched_padding, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if output_hidden_states: all_hidden_states = [model_input] + all_hidden_states @@ -281,14 +781,18 @@ def decode( current_padding = paddings[:, 0 : final_out.shape[1]] input_ts = final_out[:, -max_len:] input_padding = current_padding[:, -max_len:] - fprop_outputs, all_attentions, all_hidden_states = self.forward(input_ts, input_padding, freq, output_attentions=output_attentions, output_hidden_states=output_hidden_states) + fprop_outputs, all_attentions, all_hidden_states = self.forward( + input_ts, + input_padding, + freq, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if return_forecast_on_context and step_index == 0: # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - new_full_ts = fprop_outputs.view( - new_full_ts.size(0), -1, new_full_ts.size(3) - ) + new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1, new_full_ts.size(3)) full_outputs.append(new_full_ts) @@ -333,9 +837,7 @@ def __init__(self, config: TimesFMConfig): # Initialize weights and apply final processing self.post_init() - def _preprocess( - self, inputs: Sequence[np.array], freq: Sequence[int] - ) -> tuple[np.array, np.array, int]: + def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: """Formats and pads raw inputs to feed into the model. This function both pads each time series to match the context length, and @@ -353,7 +855,6 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - print(">>> TimesFMModel _preprocess", len(inputs), inputs[0].shape) input_ts, input_padding, inp_freq = [], [], [] for i, ts in enumerate(inputs): @@ -361,12 +862,8 @@ def _preprocess( padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = np.concatenate( - [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0 - ) - padding = np.concatenate( - [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0 - ) + ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) + padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] @@ -375,8 +872,6 @@ def _preprocess( input_padding.append(padding) inp_freq.append(freq[i]) - print(">>> TimesFMModel input_ts", len(input_ts), input_ts[0].shape) - return ( np.stack(input_ts, axis=0), np.stack(input_padding, axis=0), @@ -428,13 +923,12 @@ def forward( else: fcontext_len = forecast_context_len inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - print(">>> TimesFMModel forward", len(inputs), inputs[0].shape) inp_min = np.min([np.min(ts) for ts in inputs]) if window_size is not None: new_inputs = [] for ts in inputs: - new_inputs.extend(moving_average(ts, window_size)) + new_inputs.extend(timesfm_moving_average(ts, window_size)) inputs = new_inputs if freq is None: @@ -447,7 +941,6 @@ def forward( output_hidden_states = self.config.output_hidden_states input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - print(">>> TimesFMModel input_ts", input_ts.shape) input_ts_in = torch.from_numpy( np.array( diff --git a/src/transformers/models/timesfm/timesfm_layers.py b/src/transformers/models/timesfm/timesfm_layers.py deleted file mode 100644 index 91fd460a120d..000000000000 --- a/src/transformers/models/timesfm/timesfm_layers.py +++ /dev/null @@ -1,597 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Pytorch version of patched decoder.""" - - -import math -from typing import List, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - - -def masked_mean_std( - inputs: torch.Tensor, padding: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Calculates mean and standard deviation of `inputs` across axis 1. - - It excludes values where `padding` is 1. - - Args: - inputs: A PyTorch tensor of shape [b, n, p]. - padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. - - Returns: - A tuple containing the mean and standard deviation. - We return the statistics of the first patch with more than three non-padded values. - """ - - # Selecting the first patch with more than 3 unpadded values. - def _get_patch_index(arr: torch.Tensor): - indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) - row_sum = (arr >= 3).to(torch.int32).sum(dim=1) - return torch.where(row_sum == 0, arr.shape[1] - 1, indices) - - pad_sum = torch.sum(1 - padding, dim=2) - patch_indices = _get_patch_index(pad_sum) - bidxs = torch.arange(inputs.shape[0]) - - arr = inputs[bidxs, patch_indices, :] - pad = padding[bidxs, patch_indices, :] - - # Create a mask where padding is 0 - mask = 1 - pad - - # Calculate the number of valid elements - num_valid_elements = torch.sum(mask, dim=1) - num_valid_elements = torch.where( - num_valid_elements == 0, - torch.tensor( - 1, dtype=num_valid_elements.dtype, device=num_valid_elements.device - ), - num_valid_elements, - ) - - # Calculate the masked sum and squared sum - masked_sum = torch.sum(arr * mask, dim=1) - masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) - - # Calculate the masked mean and standard deviation - masked_mean = masked_sum / num_valid_elements - masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 - masked_var = torch.where( - masked_var < 0.0, - torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), - masked_var, - ) - masked_std = torch.sqrt(masked_var) - - return masked_mean, masked_std - - -def shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: - """Shifts rows of seq based on the first 0 in each row of the mask. - - Args: - mask: mask tensor of shape [B, N] - seq: seq tensor of shape [B, N, P] - - Returns: - The shifted sequence. - """ - batch_size, num_seq, feature_dim = seq.shape - - new_mask: torch.BoolTensor = mask == 0 - - # Use argmax to find the first True value in each row - indices = new_mask.to(torch.int32).argmax(dim=1) - - # Handle rows with all zeros - indices[~new_mask.any(dim=1)] = -1 - - # Create index ranges for each sequence in the batch - idx_range = ( - torch.arange(num_seq) - .to(seq.device) - .unsqueeze(0) - .unsqueeze(-1) - .expand(batch_size, -1, feature_dim) - ) - - # Calculate shifted indices for each element in each sequence - shifted_idx = (idx_range - indices[:, None, None]) % num_seq - - # Gather values from seq using shifted indices - shifted_seq = seq.gather(1, shifted_idx) - - return shifted_seq - - -def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: - """Returns a large negative value for the given dtype.""" - if dtype.is_floating_point: - dtype_max = torch.finfo(dtype).max - else: - dtype_max = torch.iinfo(dtype).max - return torch.tensor(-0.7 * dtype_max, dtype=dtype) - - -def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """Applies a floating-point mask to a set of logits. - - Args: - logits: A torch.Tensor of logit values. - mask: A torch.Tensor (float32) of mask values with the encoding described - in the function documentation. - - Returns: - Masked logits. - """ - - min_value = get_large_negative_number(logits.dtype) - - return torch.where((mask >= min_value * 0.5), logits, min_value) - - -def convert_paddings_to_mask( - paddings: torch.Tensor, dtype: torch.dtype = torch.float32 -) -> torch.Tensor: - """Converts binary paddings to a logit mask ready to add to attention matrix. - - Args: - paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding - token. - dtype: data type of the input. - - Returns: - A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. - """ - attention_mask = paddings.detach().clone() - attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis - attention_mask *= get_large_negative_number(dtype) - return attention_mask - - -def causal_mask(input_t: torch.Tensor) -> torch.Tensor: - """Computes and returns causal mask. - - Args: - input_t: A torch.Tensor of shape [B, T, D]. - - Returns: - An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has - already been converted to large negative values. - """ - assert input_t.dtype.is_floating_point, input_t.dtype - large_negative_number = get_large_negative_number(input_t.dtype) - t = input_t.shape[1] - col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) - row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) - mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number - return ( - mask.unsqueeze(0).unsqueeze(0).to(input_t.device) - ) # Equivalent to jnp.newaxis - - -def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Merges 2 masks. - - logscale mask is expected but 0/1 mask is also fine. - - Args: - a: torch.Tensor of shape [1|B, 1, 1|T, S]. - b: torch.Tensor of shape [1|B, 1, 1|T, S]. - - Returns: - torch.Tensor of shape [1|B, 1, 1|T, S]. - """ - - def expand_t(key_mask): - query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose - return torch.minimum(query_mask, key_mask) - - if a.shape[2] != b.shape[2]: - if a.shape[2] == 1: - a = expand_t(a) - else: - assert b.shape[2] == 1 - b = expand_t(b) - - assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." - return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum - - -def process_group(key, group, value_name, forecast_context_len): - group = group.tail(forecast_context_len) - return np.array(group[value_name], dtype=np.float32), key - - -def moving_average(arr, window_size): - """Calculates the moving average using NumPy's convolution function.""" - # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size - return [smoothed_arr, arr - smoothed_arr] - - -def freq_map(freq: str): - """Returns the frequency map for the given frequency string.""" - freq = str.upper(freq) - if ( - freq.endswith("H") - or freq.endswith("T") - or freq.endswith("MIN") - or freq.endswith("D") - or freq.endswith("B") - or freq.endswith("U") - ): - return 0 - elif freq.endswith(("W", "M", "MS")): - return 1 - elif freq.endswith("Y") or freq.endswith("Q"): - return 2 - else: - raise ValueError(f"Invalid frequency: {freq}") - - -class ResidualBlock(nn.Module): - """TimesFM residual block.""" - - def __init__( - self, - input_dims, - hidden_dims, - output_dims, - ): - super(ResidualBlock, self).__init__() - self.input_dims = input_dims - self.hidden_dims = hidden_dims - self.output_dims = output_dims - - # Hidden Layer - self.hidden_layer = nn.Sequential( - nn.Linear(input_dims, hidden_dims), - nn.SiLU(), - ) - - # Output Layer - self.output_layer = nn.Linear(hidden_dims, output_dims) - # Residual Layer - self.residual_layer = nn.Linear(input_dims, output_dims) - - def forward(self, x): - hidden = self.hidden_layer(x) - output = self.output_layer(hidden) - residual = self.residual_layer(x) - return output + residual - - -class RMSNorm(torch.nn.Module): - """Pax rms norm in pytorch.""" - - def __init__( - self, - dim: int, - eps: float = 1e-6, - add_unit_offset: bool = False, - ): - super().__init__() - self.eps = eps - self.add_unit_offset = add_unit_offset - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()) - if self.add_unit_offset: - output = output * (1 + self.weight.float()) - else: - output = output * self.weight.float() - return output.type_as(x) - - -class TransformerMLP(nn.Module): - """Pax transformer MLP in pytorch.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size) - self.down_proj = nn.Linear(intermediate_size, hidden_size) - self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) - - def forward(self, x, paddings=None): - gate_inp = self.layer_norm(x) - gate = self.gate_proj(gate_inp) - gate = F.relu(gate) - outputs = self.down_proj(gate) - if paddings is not None: - outputs = outputs * (1.0 - paddings[:, :, None]) - return outputs + x - - -class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM.""" - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - ): - super().__init__() - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.hidden_size = hidden_size - self.head_dim = head_dim - - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = nn.Parameter( - torch.empty((self.head_dim,), dtype=torch.float32), - ) - - self.qkv_proj = nn.Linear( - self.hidden_size, - (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, - ) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) - - def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: - # [batch_size, n_local_heads, input_len, head_dim] - r_softplus_0 = 1.442695041 - softplus_func = torch.nn.Softplus() - scale = r_softplus_0 / math.sqrt(self.head_dim) - scale = scale * softplus_func(self.scaling) - return query * scale[None, None, None, :] - - def forward( - self, - hidden_states: torch.Tensor, - mask: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - hidden_states_shape = hidden_states.shape - assert len(hidden_states_shape) == 3 - - print(">>> TimesFMAttention hidden_states_shape", hidden_states_shape) - batch_size, input_len, _ = hidden_states_shape - - qkv = self.qkv_proj(hidden_states) - xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) - xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xq = self._per_dim_scaling(xq) - - # Write new kv cache. - # [batch_size, input_len, n_local_kv_heads, head_dim] - if kv_cache is not None and kv_write_indices is not None: - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - - key = k_cache - value = v_cache - else: - key = xk - value = xv - if self.num_kv_heads != self.num_heads: - # [batch_size, max_seq_len, n_local_heads, head_dim] - key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) - value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) - - # [batch_size, n_local_heads, input_len, head_dim] - q = xq.transpose(1, 2) - # [batch_size, n_local_heads, max_seq_len, head_dim] - k = key.transpose(1, 2) - v = value.transpose(1, 2) - - # [batch_size, n_local_heads, input_len, max_seq_len] - scores = torch.matmul(q, k.transpose(2, 3)) - scores = scores + mask - scores = F.softmax(scores.float(), dim=-1).type_as(q) - - # [batch_size, n_local_heads, input_len, head_dim] - output = torch.matmul(scores, v) - # return scores, output.transpose(1, 2).contiguous() - - # [batch_size, input_len, hidden_dim] - output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) - output = self.o_proj(output) - return scores, output - - -class TimesFMDecoderLayer(nn.Module): - """Transformer layer.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - rms_norm_eps: float = 1e-6, - ): - super().__init__() - self.self_attn = TimesFMAttention( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - ) - self.mlp = TransformerMLP( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - mask: torch.Tensor, - paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - # Self Attention - print(">>> TimesFMDecoderLayer hidden_states", hidden_states.shape) - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - scores, hidden_states = self.self_attn( - hidden_states=hidden_states, - mask=mask, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - ) - hidden_states = residual + hidden_states - - # MLP - hidden_states = self.mlp(hidden_states, paddings=paddings) - - return scores, hidden_states - - -class StackedDecoder(nn.Module): - """Stacked transformer layer.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - num_layers: int, - rms_norm_eps: float = 1e-6, - ): - super().__init__() - - self.layers = nn.ModuleList() - for _ in range(num_layers): - self.layers.append( - TimesFMDecoderLayer( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - rms_norm_eps=rms_norm_eps, - ) - ) - - def forward( - self, - hidden_states: torch.Tensor, - paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - ) -> torch.Tensor: - print(">>> StackedDecoder hidden_states", hidden_states.shape) - padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) - atten_mask = causal_mask(hidden_states) - mask = merge_masks(padding_mask, atten_mask) - all_attentions = [] - all_hidden_states = [] - - for i in range(len(self.layers)): - layer = self.layers[i] - kv_cache = kv_caches[i] if kv_caches is not None else None - scores, hidden_states = layer( - hidden_states=hidden_states, - mask=mask, - paddings=paddings, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - ) - if output_attentions: - all_attentions.append(scores) - if output_hidden_states: - all_hidden_states.append(hidden_states) - - return hidden_states, all_attentions, all_hidden_states - - -class PositionalEmbedding(torch.nn.Module): - """Generates position embedding for a given 1-d sequence. - - Attributes: - min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. - max_timescale: End of the geometric index. Determines the frequency of the - added signal. - embedding_dims: Dimension of the embedding to be generated. - """ - - def __init__( - self, - embedding_dims: int, - min_timescale: int = 1, - max_timescale: int = 10_000, - ) -> None: - super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dims = embedding_dims - - def forward(self, seq_length=None, position=None): - """Generates a Tensor of sinusoids with different frequencies. - - Args: - seq_length: an optional Python int defining the output sequence length. - if the `position` argument is specified. - position: [B, seq_length], optional position for each token in the - sequence, only required when the sequence is packed. - - Returns: - [B, seqlen, D] if `position` is specified, else [1, seqlen, D] - """ - if position is None: - assert seq_length is not None - # [1, seqlen] - position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) - else: - assert position.ndim == 2, position.shape - - num_timescales = self.embedding_dims // 2 - log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale) - ) / max(num_timescales - 1, 1) - inv_timescales = self.min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(0) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) - # Padding to ensure correct embedding dimension - signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) - return signal diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index c928cf9aca8e..c6d8f932730b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -18,7 +18,6 @@ from typing import List import numpy as np -import torch from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( From c55088da39728c7a7b9d7d5dace781005d9a4ff0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 20:32:02 +0100 Subject: [PATCH 107/242] add intermediate_size and initialize with config --- .../models/timesfm/configuration_timesfm.py | 10 ++- .../models/timesfm/modeling_timesfm.py | 79 ++++--------------- tests/models/timesfm/test_modeling_timesfm.py | 3 + 3 files changed, 26 insertions(+), 66 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 0ff463ba270d..26cf828f0da9 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -47,6 +47,8 @@ class TimesFMConfig(PretrainedConfig): Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will be defined as `num_heads * head_dim`. @@ -69,12 +71,14 @@ class TimesFMConfig(PretrainedConfig): initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. """ model_type = "timesfm" keys_to_ignore_at_inference = [] attribute_map = { - "hidden_size": "hidden_size", + "hidden_size": "model_dim", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } @@ -88,6 +92,7 @@ def __init__( freq_size: int = 3, num_layers: int = 20, model_dim: int = 1280, + intermediate_size: int = 1280, head_dim: int = 80, num_heads: int = 16, dropout_rate: float = 0.1, @@ -98,6 +103,7 @@ def __init__( use_positional_embedding: bool = True, batch_size: int = 32, initializer_factor: float = 1.0, + attention_dropout: float = 0.0, **kwargs, ): self.patch_len = patch_len @@ -107,6 +113,7 @@ def __init__( self.pad_val = pad_val self.freq_size = freq_size self.model_dim = model_dim + self.intermediate_size = intermediate_size self.head_dim = head_dim self.num_layers = num_layers self.num_heads = num_heads @@ -116,6 +123,7 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.batch_size = batch_size self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 9e0959b8e12f..7abfbb9f5ee3 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -181,23 +181,16 @@ def forward(self, seq_length=None, position=None): class TimesFMAttention(nn.Module): """Implements the attention used in TimesFM.""" - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - ): + def __init__(self, config: TimesFMConfig): super().__init__() - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads + self.num_heads = config.num_heads + self.num_kv_heads = config.num_heads assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.hidden_size = hidden_size - self.head_dim = head_dim + self.hidden_size = config.model_dim + self.head_dim = config.head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -274,33 +267,17 @@ def forward( # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) output = self.o_proj(output) - return scores, output + return output, scores class TimesFMDecoderLayer(nn.Module): """Transformer layer.""" - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - rms_norm_eps: float = 1e-6, - ): + def __init__(self, config: TimesFMConfig): super().__init__() - self.self_attn = TimesFMAttention( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - ) - self.mlp = TimesFMTransformerMLP( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - ) - self.input_layernorm = TimesFMRMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = TimesFMAttention(config) + self.mlp = TimesFMTransformerMLP(config.model_dim, config.intermediate_size) + self.input_layernorm = TimesFMRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( self, @@ -313,7 +290,7 @@ def forward( # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - scores, hidden_states = self.self_attn( + hidden_states, scores = self.self_attn( hidden_states=hidden_states, mask=mask, kv_write_indices=kv_write_indices, @@ -330,30 +307,10 @@ def forward( class TimesFMStackedDecoder(nn.Module): """Stacked transformer layer.""" - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - num_layers: int, - rms_norm_eps: float = 1e-6, - ): + def __init__(self, config: TimesFMConfig): super().__init__() - self.layers = nn.ModuleList() - for _ in range(num_layers): - self.layers.append( - TimesFMDecoderLayer( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - rms_norm_eps=rms_norm_eps, - ) - ) + self.layers = nn.ModuleList([TimesFMDecoderLayer(config) for _ in range(config.num_layers)]) def forward( self, @@ -602,15 +559,7 @@ def __init__(self, config: TimesFMConfig): output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.model_dim, ) - self.stacked_transformer = TimesFMStackedDecoder( - hidden_size=self.config.model_dim, - intermediate_size=self.config.model_dim, - num_heads=self.config.num_heads, - num_kv_heads=self.config.num_heads, - head_dim=self.config.head_dim, - num_layers=self.config.num_layers, - rms_norm_eps=self.config.rms_norm_eps, - ) + self.stacked_transformer = TimesFMStackedDecoder(config=config) if self.config.use_positional_embedding: self.position_emb = TimesFMPositionalEmbedding( embedding_dims=self.config.model_dim, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index c6d8f932730b..394bea5e1b3b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -56,6 +56,7 @@ def __init__( freq_size: int = 3, num_layers: int = 4, model_dim: int = 128, + intermediate_size: int = 1280, head_dim: int = 16, num_heads: int = 4, dropout_rate: float = 0.1, @@ -76,6 +77,7 @@ def __init__( self.pad_val = pad_val self.freq_size = freq_size self.model_dim = model_dim + self.intermediate_size = intermediate_size self.head_dim = head_dim self.num_hidden_layers = num_layers self.num_attention_heads = num_heads @@ -103,6 +105,7 @@ def get_config(self): pad_val=self.pad_val, freq_size=self.freq_size, model_dim=self.model_dim, + intermediate_size=self.intermediate_size, head_dim=self.head_dim, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, From fd270d980f5c3fff2e518202540a076b14b28d1c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 20:37:58 +0100 Subject: [PATCH 108/242] initial documentation --- docs/source/en/model_doc/timesfm.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 9acc824f9e0f..1ae971603246 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -18,19 +18,19 @@ rendered properly in your Markdown viewer. ## Overview -The TimesFM model was proposed in []() by . - +TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder only model that uses non-overlapping patches of time-series data as input and outputs some output patch length prediction in an autoregressive fashion. + The abstract from the paper is the following: -** +*Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.* Tips: This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +The original code can be found [here](https://github.com/google-research/timesfm). ## TimesFMConfig From 9bb5a49755f29719e50ed21e079ae64779a3705a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 20:49:06 +0100 Subject: [PATCH 109/242] rename mask to attention_mask --- src/transformers/models/timesfm/modeling_timesfm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 7abfbb9f5ee3..ca3b41caf6e5 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -215,7 +215,7 @@ def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - mask: torch.Tensor, + attention_mask: torch.Tensor | None = None, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: @@ -257,7 +257,10 @@ def forward( # [batch_size, n_local_heads, input_len, max_seq_len] scores = torch.matmul(q, k.transpose(2, 3)) - scores = scores + mask + + if attention_mask is not None: + scores = scores + attention_mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) # [batch_size, n_local_heads, input_len, head_dim] @@ -282,7 +285,7 @@ def __init__(self, config: TimesFMConfig): def forward( self, hidden_states: torch.Tensor, - mask: torch.Tensor, + attention_mask: torch.Tensor, paddings: torch.Tensor, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, @@ -292,7 +295,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) hidden_states, scores = self.self_attn( hidden_states=hidden_states, - mask=mask, + attention_mask=attention_mask, kv_write_indices=kv_write_indices, kv_cache=kv_cache, ) @@ -332,7 +335,7 @@ def forward( kv_cache = kv_caches[i] if kv_caches is not None else None scores, hidden_states = layer( hidden_states=hidden_states, - mask=mask, + attention_mask=mask, paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, From 5376dd770b2e257b4328c44b4189599a56a69081 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 21:25:02 +0100 Subject: [PATCH 110/242] smaller tests --- .../models/timesfm/modeling_timesfm.py | 27 +++++++++++++------ tests/models/timesfm/test_modeling_timesfm.py | 10 +++---- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ca3b41caf6e5..a28f9dea50e8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -218,7 +218,8 @@ def forward( attention_mask: torch.Tensor | None = None, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -270,6 +271,10 @@ def forward( # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) output = self.o_proj(output) + + if output_attentions: + scores = None + return output, scores @@ -323,7 +328,7 @@ def forward( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> torch.Tensor: + ) -> BaseModelOutput: padding_mask = timesfm_convert_paddings_to_mask(paddings, hidden_states.dtype) atten_mask = timesfm_causal_mask(hidden_states) mask = timesfm_merge_masks(padding_mask, atten_mask) @@ -345,7 +350,11 @@ def forward( if output_hidden_states: all_hidden_states.append(hidden_states) - return hidden_states, all_attentions, all_hidden_states + return BaseModelOutput( + last_hidden_state=hidden_states, + attentions=all_attentions, + hidden_states=all_hidden_states, + ) # Move utility functions here @@ -664,7 +673,7 @@ def forward( freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: num_outputs = len(self.config.quantiles) + 1 model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, @@ -673,17 +682,19 @@ def forward( f_emb = self.freq_emb(freq) # B x 1 x D model_input += f_emb - model_output, all_attentions, all_hidden_states = self.stacked_transformer( + transformer_output = self.stacked_transformer( model_input, patched_padding, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if output_hidden_states: - all_hidden_states = [model_input] + all_hidden_states + all_hidden_states = [model_input] + transformer_output.hidden_states + else: + all_hidden_states = None - output_ts = self._postprocess_output(model_output, num_outputs, stats) - return output_ts, all_attentions, all_hidden_states + output_ts = self._postprocess_output(transformer_output.last_hidden_state, num_outputs, stats) + return output_ts, transformer_output.attentions, all_hidden_states def decode( self, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 394bea5e1b3b..2b303d12d1ac 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -54,11 +54,11 @@ def __init__( context_len: int = 512, horizon_len: int = 128, freq_size: int = 3, - num_layers: int = 4, - model_dim: int = 128, - intermediate_size: int = 1280, - head_dim: int = 16, - num_heads: int = 4, + num_layers: int = 1, + model_dim: int = 16, + intermediate_size: int = 32, + head_dim: int = 2, + num_heads: int = 2, dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, From 8edb51e41bca5b919302a8eeb391b5d59946cc73 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 21:36:03 +0100 Subject: [PATCH 111/242] fixup --- src/transformers/models/auto/configuration_auto.py | 4 ++-- src/transformers/models/auto/modeling_auto.py | 4 ++-- tests/models/timesfm/test_modeling_timesfm.py | 5 +---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e906ed233747..c5c8eab382a9 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -291,11 +291,11 @@ ("swinv2", "Swinv2Config"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), - ("timesfm", "TimesFMConfig"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), + ("timesfm", "TimesFMConfig"), ("timesformer", "TimesformerConfig"), ("timm_backbone", "TimmBackboneConfig"), ("timm_wrapper", "TimmWrapperConfig"), @@ -640,13 +640,13 @@ ("swinv2", "Swin Transformer V2"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), - ("timesfm", "TimesFM"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), ("tapex", "TAPEX"), ("textnet", "TextNet"), ("time_series_transformer", "Time Series Transformer"), + ("timesfm", "TimesFM"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d6fef140ed71..6a954bf2ce95 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -267,11 +267,11 @@ ("swinv2", "Swinv2Model"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), - ("timesfm", "TimesFMModel"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), + ("timesfm", "TimesFMModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), @@ -380,8 +380,8 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), - ("timesfm", "TimesFMModel"), ("tapas", "TapasForMaskedLM"), + ("timesfm", "TimesFMModel"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 2b303d12d1ac..71700a1e8102 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -40,7 +40,6 @@ if is_torch_available(): - from transformers import ( TimesFMModel, ) @@ -151,9 +150,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class TimesFMModelTest( - ModelTesterMixin, unittest.TestCase -): +class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (TimesFMModel,) if is_torch_available() else () all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () all_parallelizable_model_classes = () From e8e31cd7c8c737e6906bfa1cffca4a2b4c85b270 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Nov 2024 21:37:42 +0100 Subject: [PATCH 112/242] fix copies --- docs/source/en/index.md | 1 + .../models/timesfm/configuration_timesfm.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index a6961b06a47b..3971e67557b1 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -340,6 +340,7 @@ Flax), PyTorch, and/or TensorFlow. | [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ | | [TextNet](model_doc/textnet) | ✅ | ❌ | ❌ | | [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ | +| [TimesFM](model_doc/timesfm) | ✅ | ❌ | ❌ | | [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ | | [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ | | [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ | diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 26cf828f0da9..aa6a64e69bce 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -47,7 +47,7 @@ class TimesFMConfig(PretrainedConfig): Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. - intermediate_size (`int`, *optional*, defaults to 11008): + intermediate_size (`int`, *optional*, defaults to 1280): Dimension of the MLP representations. head_dim (`int`, *optional*, defaults to 80): Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d409238588d0..74cc0293e95a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9487,6 +9487,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class TimesFMModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TimesFMPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TimesformerForVideoClassification(metaclass=DummyObject): _backends = ["torch"] From 5b184400cebfab647726fe8ebe7b1af9996083c0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 12:54:43 +0100 Subject: [PATCH 113/242] move to time series section --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 330ac4a83b03..ab4ab0840579 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -615,8 +615,6 @@ title: T5v1.1 - local: model_doc/tapex title: TAPEX - - local: model_doc/timesfm - title: TimesFM - local: model_doc/transfo-xl title: Transformer XL - local: model_doc/ul2 @@ -1015,6 +1013,8 @@ title: PatchTSMixer - local: model_doc/patchtst title: PatchTST + - local: model_doc/timesfm + title: TimesFM - local: model_doc/time_series_transformer title: Time Series Transformer title: Time series models From 5ebeec2bf18af8d43565b18b18eebf16058d21ff Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 12:56:24 +0100 Subject: [PATCH 114/242] sort docs --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ab4ab0840579..ed213d04c192 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1013,10 +1013,10 @@ title: PatchTSMixer - local: model_doc/patchtst title: PatchTST - - local: model_doc/timesfm - title: TimesFM - local: model_doc/time_series_transformer title: Time Series Transformer + - local: model_doc/timesfm + title: TimesFM title: Time series models - isExpanded: false sections: From f810125419e3f09d7e61dd3f71e2758c7b1f71c3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 12:58:46 +0100 Subject: [PATCH 115/242] isort fix --- src/transformers/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5cee1d88630f..bf2932454f8a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -813,7 +813,6 @@ "models.swinv2": ["Swinv2Config"], "models.switch_transformers": ["SwitchTransformersConfig"], "models.t5": ["T5Config"], - "models.timesfm": ["TimesFMConfig"], "models.table_transformer": ["TableTransformerConfig"], "models.tapas": [ "TapasConfig", @@ -821,6 +820,7 @@ ], "models.textnet": ["TextNetConfig"], "models.time_series_transformer": ["TimeSeriesTransformerConfig"], + "models.timesfm": ["TimesFMConfig"], "models.timesformer": ["TimesformerConfig"], "models.timm_backbone": ["TimmBackboneConfig"], "models.timm_wrapper": ["TimmWrapperConfig"], @@ -3708,12 +3708,6 @@ "load_tf_weights_in_t5", ] ) - _import_structure["models.timesfm"].extend( - [ - "TimesFMModel", - "TimesFMPreTrainedModel", - ] - ) _import_structure["models.table_transformer"].extend( [ "TableTransformerForObjectDetection", @@ -3746,6 +3740,12 @@ "TimeSeriesTransformerPreTrainedModel", ] ) + _import_structure["models.timesfm"].extend( + [ + "TimesFMModel", + "TimesFMPreTrainedModel", + ] + ) _import_structure["models.timesformer"].extend( [ "TimesformerForVideoClassification", From 7e5921cf892315af8213f398e96084045c7e73f1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 29 Nov 2024 13:06:34 +0100 Subject: [PATCH 116/242] batch_size is not a configuration --- .../models/timesfm/configuration_timesfm.py | 4 --- .../models/timesfm/modeling_timesfm.py | 29 ++----------------- tests/models/timesfm/test_modeling_timesfm.py | 3 -- 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index aa6a64e69bce..45691cd2f46a 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -66,8 +66,6 @@ class TimesFMConfig(PretrainedConfig): The value used to pad the predictions. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. - batch_size (`int`, *optional*, defaults to 32): - The batch size. initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). @@ -101,7 +99,6 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - batch_size: int = 32, initializer_factor: float = 1.0, attention_dropout: float = 0.0, **kwargs, @@ -121,7 +118,6 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.batch_size = batch_size self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index a28f9dea50e8..a2b2d6616d4c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -788,14 +788,6 @@ def __init__(self, config: TimesFMConfig): self.context_len = config.context_len self.horizon_len = config.horizon_len - self.input_patch_len = config.patch_len - self.output_patch_len = config.horizon_len - self.num_layers = config.num_layers - self.model_dims = config.model_dim - self.quantiles = config.quantiles - self.num_heads = config.num_heads - self.batch_size = config.batch_size - self._horizon_start = self.context_len - self.input_patch_len # Initialize weights and apply final processing self.post_init() @@ -905,24 +897,9 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - input_ts_in = torch.from_numpy( - np.array( - input_ts, - dtype=np.float32, - ) - ) - input_padding_in = torch.from_numpy( - np.array( - input_padding, - dtype=np.float32, - ) - ) - inp_freq_in = torch.from_numpy( - np.array( - inp_freq, - dtype=np.int32, - ) - ).long() + input_ts_in = torch.from_numpy(np.array(input_ts, dtype=np.float32)) + input_padding_in = torch.from_numpy(np.array(input_padding, dtype=np.float32)) + inp_freq_in = torch.from_numpy(np.array(inp_freq, dtype=np.int32)).long() mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decoder.decode( input_ts=input_ts_in, paddings=input_padding_in, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 71700a1e8102..8d43a5b1498b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -64,7 +64,6 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, - batch_size: int = 32, initializer_factor: float = 0.0, is_training: bool = False, ): @@ -84,7 +83,6 @@ def __init__( self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding - self.batch_size = batch_size self.initializer_factor = initializer_factor self.is_training = is_training @@ -112,7 +110,6 @@ def get_config(self): tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, use_positional_embedding=self.use_positional_embedding, - batch_size=self.batch_size, initializer_factor=self.initializer_factor, ) From 906d6a89732ce8d1706487669aa5dead2dc57b2d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 30 Nov 2024 18:28:31 +0100 Subject: [PATCH 117/242] rename to TimesFMModelForPrediction --- docs/source/en/model_doc/timesfm.md | 28 +-- src/transformers/__init__.py | 6 +- src/transformers/models/auto/modeling_auto.py | 4 +- src/transformers/models/timesfm/__init__.py | 8 +- .../models/timesfm/configuration_timesfm.py | 10 +- .../models/timesfm/modeling_timesfm.py | 216 ++++++++++-------- src/transformers/utils/dummy_pt_objects.py | 9 +- tests/models/timesfm/test_modeling_timesfm.py | 16 +- 8 files changed, 149 insertions(+), 148 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 1ae971603246..76cf1f8afef2 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -37,34 +37,14 @@ The original code can be found [here](https://github.com/google-research/timesfm [[autodoc]] TimesFMConfig -## TimesFMModel +## TimesFMDecoder -[[autodoc]] TimesFMModel +[[autodoc]] TimesFMDecoder - forward -## TimesFMForConditionalGeneration +## TimesFMModelForPrediction -[[autodoc]] TimesFMForConditionalGeneration - - forward - -## TimesFMEncoderModel - -[[autodoc]] TimesFMEncoderModel - - forward - -## TimesFMForSequenceClassification - -[[autodoc]] TimesFMForSequenceClassification - - forward - -## TimesFMForTokenClassification - -[[autodoc]] TimesFMForTokenClassification - - forward - -## TimesFMForQuestionAnswering - -[[autodoc]] TimesFMForQuestionAnswering +[[autodoc]] TimesFMModelForPrediction - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bf2932454f8a..a38dc203485d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3742,7 +3742,8 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMModel", + "TimesFMModelForPrediction", + "TimesFMDecoder", "TimesFMPreTrainedModel", ] ) @@ -8454,7 +8455,8 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFMModel, + TimesFMDecoder, + TimesFMModelForPrediction, TimesFMPreTrainedModel, ) from .models.timesformer import ( diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6a954bf2ce95..4ec04c736771 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -271,7 +271,7 @@ ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), - ("timesfm", "TimesFMModel"), + ("timesfm", "TimesFMModelForPrediction"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), @@ -381,7 +381,7 @@ ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), - ("timesfm", "TimesFMModel"), + ("timesfm", "TimesFMModelForPrediction"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 6592a5b1620e..51028a860782 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -30,7 +30,8 @@ pass else: _import_structure["modeling_timesfm"] = [ - "TimesFMModel", + "TimesFMModelForPrediction", + "TimesFMDecoder", "TimesFMPreTrainedModel", ] @@ -43,10 +44,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_timesfm import ( - TimesFMModel, - TimesFMPreTrainedModel, - ) + from .modeling_timesfm import TimesFMDecoder, TimesFMModelForPrediction, TimesFMPreTrainedModel else: import sys diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 45691cd2f46a..eb135c03c968 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -26,7 +26,7 @@ class TimesFMConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`TimesFMModel`] or a [`TFTimesFMModel`]. It is used to + This is the configuration class to store the configuration of a [`TimesFMModelForPrediction`] or a [`TFTimesFMDecoder`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. @@ -54,8 +54,6 @@ class TimesFMConfig(PretrainedConfig): be defined as `num_heads * head_dim`. num_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. - dropout_rate (`float`, *optional*, defaults to 0.1): - The ratio for all dropout layers. tolerance (`float`, *optional*, defaults to 1e-06): The tolerance for the quantile loss. rms_norm_eps (`float`, *optional*, defaults to 1e-06): @@ -69,8 +67,6 @@ class TimesFMConfig(PretrainedConfig): initializer_factor (`float`, *optional*, defaults to 1.0): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. """ model_type = "timesfm" @@ -93,14 +89,12 @@ def __init__( intermediate_size: int = 1280, head_dim: int = 80, num_heads: int = 16, - dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, use_positional_embedding: bool = True, initializer_factor: float = 1.0, - attention_dropout: float = 0.0, **kwargs, ): self.patch_len = patch_len @@ -114,12 +108,10 @@ def __init__( self.head_dim = head_dim self.num_layers = num_layers self.num_heads = num_heads - self.dropout_rate = dropout_rate self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding self.initializer_factor = initializer_factor - self.attention_dropout = attention_dropout super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index a2b2d6616d4c..f2134e90e463 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -37,9 +37,15 @@ @dataclass -class TimesFMOutput(BaseModelOutput): - mean_predictions: np.ndarray = None - full_predictions: np.ndarray = None +class TimesFMDecoderOutput(BaseModelOutput): + loc: np.ndarray | None = None + scale: np.ndarray | None = None + + +@dataclass +class TimesFMOutputForPrediction(BaseModelOutput): + mean_predictions: np.ndarray | None = None + full_predictions: np.ndarray | None = None class TimesFMTransformerMLP(nn.Module): @@ -553,8 +559,8 @@ def _init_weights(self, module): pass -class PatchedTimeSeriesDecoder(TimesFMPreTrainedModel): - """Patched time-series decoder.""" +class TimesFMDecoder(TimesFMPreTrainedModel): + """Patched time-series decoder without any specific output layer.""" def __init__(self, config: TimesFMConfig): super().__init__(config) @@ -566,11 +572,6 @@ def __init__(self, config: TimesFMConfig): hidden_dims=config.model_dim, ) self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) - self.horizon_ff_layer = TimesFMResidualBlock( - input_dims=config.model_dim, - output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.model_dim, - ) self.stacked_transformer = TimesFMStackedDecoder(config=config) if self.config.use_positional_embedding: self.position_emb = TimesFMPositionalEmbedding( @@ -600,11 +601,6 @@ def _forward_transform( ) return outputs, (mu, sigma) - def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """Output is of shape [B, N, P, Q].""" - mu, sigma = stats - return outputs * sigma[:, None, None, None] + mu[:, None, None, None] - def _preprocess_input( self, input_ts: torch.Tensor, @@ -649,23 +645,6 @@ def _preprocess_input( return model_input, patched_padding, stats, patched_inputs - def _postprocess_output( - self, - model_output: torch.Tensor, - num_outputs: int, - stats: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - """Postprocess output of stacked transformer.""" - - # B x N x (H.Q) - output_ts = self.horizon_ff_layer(model_output) - - # Reshape using view - b, n, _ = output_ts.shape - output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) - - return self._reverse_transform(output_ts, stats) - def forward( self, input_ts: torch.Tensor, @@ -673,8 +652,7 @@ def forward( freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - num_outputs = len(self.config.quantiles) + 1 + ) -> TimesFMDecoderOutput: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -693,8 +671,96 @@ def forward( else: all_hidden_states = None - output_ts = self._postprocess_output(transformer_output.last_hidden_state, num_outputs, stats) - return output_ts, transformer_output.attentions, all_hidden_states + return TimesFMDecoderOutput( + last_hidden_state=transformer_output.last_hidden_state, + hidden_states=all_hidden_states, + attentions=transformer_output.attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + + +class TimesFMModelForPrediction(TimesFMPreTrainedModel): + def __init__(self, config: TimesFMConfig): + super().__init__(config) + + self.config = config + self.context_len = config.context_len + self.horizon_len = config.horizon_len + + self.decoder = TimesFMDecoder(config) + + # quantile and mean output + self.horizon_ff_layer = TimesFMResidualBlock( + input_dims=config.model_dim, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.model_dim, + ) + + # Initialize weights and apply final processing + self.post_init() + + def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d Tensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + input_ts, input_padding, inp_freq = [], [], [] + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) + padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + ) + + def _postprocess_output( + self, + model_output: torch.Tensor, + stats: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, len(self.config.quantiles) + 1) + + return self._reverse_transform(output_ts, stats) + + def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] def decode( self, @@ -732,6 +798,7 @@ def decode( final_out = input_ts context_len = final_out.shape[1] full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: raise ValueError( "Length of paddings must match length of input + horizon_len:" @@ -744,13 +811,18 @@ def decode( current_padding = paddings[:, 0 : final_out.shape[1]] input_ts = final_out[:, -max_len:] input_padding = current_padding[:, -max_len:] - fprop_outputs, all_attentions, all_hidden_states = self.forward( + decoder_output = self.decoder( input_ts, input_padding, freq, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) + fprop_outputs = self._postprocess_output( + decoder_output.last_hidden_state, + (decoder_output.loc, decoder_output.scale), + ) + if return_forecast_on_context and step_index == 0: # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. @@ -775,62 +847,12 @@ def decode( # `full_outputs` indexing starts at the forecast horizon. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - return full_outputs[:, :, 0], full_outputs, fprop_outputs, all_attentions, all_hidden_states - - -class TimesFMModel(TimesFMPreTrainedModel): - def __init__(self, config: TimesFMConfig): - super().__init__(config) - - self.config = config - - self.decoder = PatchedTimeSeriesDecoder(config) - - self.context_len = config.context_len - self.horizon_len = config.horizon_len - - # Initialize weights and apply final processing - self.post_init() - - def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: - """Formats and pads raw inputs to feed into the model. - - This function both pads each time series to match the context length, and - pads the inputs to meet the SPMD shape requirement. - - Args: - inputs: A list of 1d JTensors. Each JTensor is the context time series of - a single forecast task. - freq: list of frequencies - - Returns: - A tuple of: - - the padded input time series to meet the model required context. - - the padding indicator. - - the number of padded examples for SPMD so that each core has the same - number (a multiple of `batch_size`) of examples. - """ - input_ts, input_padding, inp_freq = [], [], [] - - for i, ts in enumerate(inputs): - input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) - if input_len < self.context_len: - num_front_pad = self.context_len - input_len - ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) - padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) - elif input_len > self.context_len: - ts = ts[-self.context_len :] - padding = padding[-(self.context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) - inp_freq.append(freq[i]) - return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), + full_outputs[:, :, 0], + full_outputs, + decoder_output.last_hidden_state, + decoder_output.attentions, + decoder_output.hidden_states, ) def forward( @@ -844,12 +866,12 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, - ) -> tuple[np.ndarray, np.ndarray]: + ) -> TimesFMOutputForPrediction: """Forecasts on a list of time series. Args: inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to JTensor by `jnp.array`. + should be in a format convertible to Tensor. freq: frequency of each context time series. 0 for high frequency (default), 1 for medium, and 2 for low. Notice this is different from the `freq` required by `forecast_on_df`. @@ -862,7 +884,7 @@ def forward( have non-negative values. Returns: - A tuple for JTensors: + A tuple for Tensors: - the mean forecast of size (# inputs, # forecast horizon), - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). @@ -900,7 +922,7 @@ def forward( input_ts_in = torch.from_numpy(np.array(input_ts, dtype=np.float32)) input_padding_in = torch.from_numpy(np.array(input_padding, dtype=np.float32)) inp_freq_in = torch.from_numpy(np.array(inp_freq, dtype=np.int32)).long() - mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decoder.decode( + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( input_ts=input_ts_in, paddings=input_padding_in, freq=inp_freq_in, @@ -918,7 +940,7 @@ def forward( full_outputs = torch.maximum(full_outputs, 0.0) if return_dict: - return TimesFMOutput( + return TimesFMOutputForPrediction( last_hidden_state=last_hidden_state, attentions=all_attentions if output_attentions else None, hidden_states=all_hidden_states if output_hidden_states else None, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 74cc0293e95a..78e387d4f6cc 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9487,7 +9487,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMModel(metaclass=DummyObject): +class TimesFMDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TimesFMModelForPrediction(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 8d43a5b1498b..ec7550b043f3 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -40,9 +40,7 @@ if is_torch_available(): - from transformers import ( - TimesFMModel, - ) + from transformers import TimesFMModelForPrediction class TimesFMModelTester: @@ -66,6 +64,7 @@ def __init__( use_positional_embedding: bool = True, initializer_factor: float = 0.0, is_training: bool = False, + batch_size: int = 3, ): self.parent = parent self.patch_len = patch_len @@ -85,6 +84,7 @@ def __init__( self.use_positional_embedding = use_positional_embedding self.initializer_factor = initializer_factor self.is_training = is_training + self.batch_size = batch_size # The size of test input self.seq_length = context_len // patch_len @@ -148,8 +148,8 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (TimesFMModel,) if is_torch_available() else () - all_generative_model_classes = (TimesFMModel,) if is_torch_available() else () + all_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () + all_generative_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () all_parallelizable_model_classes = () fx_compatible = False test_pruning = False @@ -164,7 +164,7 @@ def setUp(self): def test_create_and_run_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = TimesFMModel(config) + model = TimesFMModelForPrediction(config) model.to(torch_device) model.eval() results = model(**inputs_dict) @@ -180,7 +180,7 @@ def test_headmasking(self): # the main input name is `inputs` def test_model_main_input_name(self): - model_signature = inspect.signature(getattr(TimesFMModel, "forward")) + model_signature = inspect.signature(getattr(TimesFMModelForPrediction, "forward")) # The main input is the name of the argument after `self` observed_main_input_name = list(model_signature.parameters.keys())[1] - self.assertEqual(TimesFMModel.main_input_name, observed_main_input_name) + self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name) From c30e7488d22a5c345a7b3da95ebdb18dd0489840 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 4 Dec 2024 20:41:04 +0100 Subject: [PATCH 118/242] initial script --- .../convert_timesfm_orignal_to_pytorch.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py new file mode 100644 index 000000000000..bf186c06574c --- /dev/null +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py @@ -0,0 +1,76 @@ +import argparse +import os +import shutil + +import timesfm + +from transformers import TimesFMConfig, TimesFMModelForPrediction + + +""" +Sample usage: + +``` +python src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py \ + --output_dir /output/path +``` +""" + + +def write_model(model_path, safe_serialization=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cpu", + per_core_batch_size=32, + horizon_len=128, + ), + checkpoint=timesfm.TimesFmCheckpoint( + huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + ) + + timesfm_config = TimesFMConfig( + patch_len=tfm.hparams.input_patch_len, + context_len=tfm.hparams.context_len, + horizon_len=tfm.hparams.horizon_len, + num_layers=tfm.hparams.num_layers, + model_dim=tfm.hparams.model_dims, + intermediate_size=tfm.hparams.model_dims, + head_dim=tfm.hparams.model_dims//tfm.hparams.num_heads, + num_heads=tfm.hparams.num_heads, + ) + timesfm_config.save_pretrained(tmp_model_path) + timesfm_model = TimesFMModelForPrediction(timesfm_config) + + # copy the weights from the original model to the new model making + import pdb; pdb.set_trace() + orignal_model = tfm._model + + + timesfm_model.load_state_dict(tfm.state_dict()) + timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + args = parser.parse_args() + write_model( + model_path=args.output_dir, + safe_serialization=args.safe_serialization, + ) + + check_outputs(args.output_dir) + + +if __name__ == "__main__": + main() From d7d3a1368b0f5febad0e73c9739a23e532301581 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:18:12 +0100 Subject: [PATCH 119/242] add check_outputs --- .../convert_timesfm_orignal_to_pytorch.py | 173 ++++++++++++++++-- 1 file changed, 162 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py index bf186c06574c..ac638e9efe66 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py @@ -2,7 +2,9 @@ import os import shutil +import numpy as np import timesfm +import torch from transformers import TimesFMConfig, TimesFMModelForPrediction @@ -23,13 +25,12 @@ def write_model(model_path, safe_serialization=True): os.makedirs(tmp_model_path, exist_ok=True) tfm = timesfm.TimesFm( - hparams=timesfm.TimesFmHparams( - backend="cpu", - per_core_batch_size=32, - horizon_len=128, - ), - checkpoint=timesfm.TimesFmCheckpoint( - huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + hparams=timesfm.TimesFmHparams( + backend="cpu", + per_core_batch_size=32, + horizon_len=128, + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), ) timesfm_config = TimesFMConfig( @@ -39,22 +40,172 @@ def write_model(model_path, safe_serialization=True): num_layers=tfm.hparams.num_layers, model_dim=tfm.hparams.model_dims, intermediate_size=tfm.hparams.model_dims, - head_dim=tfm.hparams.model_dims//tfm.hparams.num_heads, + head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads, num_heads=tfm.hparams.num_heads, ) timesfm_config.save_pretrained(tmp_model_path) timesfm_model = TimesFMModelForPrediction(timesfm_config) # copy the weights from the original model to the new model making - import pdb; pdb.set_trace() - orignal_model = tfm._model + original_model = tfm._model + + # Map decoder input_ff_layer + timesfm_model.decoder.input_ff_layer.hidden_layer[0].weight.data = original_model.input_ff_layer.hidden_layer[ + 0 + ].weight.data + timesfm_model.decoder.input_ff_layer.hidden_layer[0].bias.data = original_model.input_ff_layer.hidden_layer[ + 0 + ].bias.data + timesfm_model.decoder.input_ff_layer.output_layer.weight.data = ( + original_model.input_ff_layer.output_layer.weight.data + ) + timesfm_model.decoder.input_ff_layer.output_layer.bias.data = original_model.input_ff_layer.output_layer.bias.data + timesfm_model.decoder.input_ff_layer.residual_layer.weight.data = ( + original_model.input_ff_layer.residual_layer.weight.data + ) + timesfm_model.decoder.input_ff_layer.residual_layer.bias.data = ( + original_model.input_ff_layer.residual_layer.bias.data + ) + # Map freq embedding + timesfm_model.decoder.freq_emb.weight.data = original_model.freq_emb.weight.data + + # Map horizon_ff_layer + timesfm_model.horizon_ff_layer.hidden_layer[0].weight.data = original_model.horizon_ff_layer.hidden_layer[ + 0 + ].weight.data + timesfm_model.horizon_ff_layer.hidden_layer[0].bias.data = original_model.horizon_ff_layer.hidden_layer[ + 0 + ].bias.data + timesfm_model.horizon_ff_layer.output_layer.weight.data = original_model.horizon_ff_layer.output_layer.weight.data + timesfm_model.horizon_ff_layer.output_layer.bias.data = original_model.horizon_ff_layer.output_layer.bias.data + timesfm_model.horizon_ff_layer.residual_layer.weight.data = ( + original_model.horizon_ff_layer.residual_layer.weight.data + ) + timesfm_model.horizon_ff_layer.residual_layer.bias.data = original_model.horizon_ff_layer.residual_layer.bias.data + + # Map transformer layers + for i in range(len(timesfm_model.decoder.stacked_transformer.layers)): + # Map attention layers + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.qkv_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.qkv_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.o_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.o_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].self_attn.scaling.data = original_model.stacked_transformer.layers[i].self_attn.scaling.data + + # Map MLP layers + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.gate_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.gate_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.down_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.down_proj.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.down_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.down_proj.bias.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.layer_norm.weight.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.weight.data + timesfm_model.decoder.stacked_transformer.layers[ + i + ].mlp.layer_norm.bias.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.bias.data + + # Map layer norms + timesfm_model.decoder.stacked_transformer.layers[ + i + ].input_layernorm.weight.data = original_model.stacked_transformer.layers[i].input_layernorm.weight.data - timesfm_model.load_state_dict(tfm.state_dict()) timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path) +def check_outputs(model_path): + """Compares outputs between original and converted models.""" + print("\nChecking model outputs...") + + # Load original model + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cpu", + per_core_batch_size=32, + horizon_len=128, + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + ) + + # Load converted model + converted_model = TimesFMModelForPrediction.from_pretrained(model_path) + converted_model.eval() # Set to evaluation mode + + # Create test inputs + forecast_input = [ + np.sin(np.linspace(0, 20, 100)), + np.sin(np.linspace(0, 20, 200)), + np.sin(np.linspace(0, 20, 400)), + ] + frequency_input = [0, 1, 2] + + # Get predictions from original model + point_forecast_orig, quantile_forecast_orig = tfm.forecast( + forecast_input, + freq=frequency_input, + ) + + # Get predictions from converted model + with torch.no_grad(): + outputs = converted_model(inputs=forecast_input, freq=frequency_input, return_dict=True) + point_forecast_conv = outputs.mean_predictions.numpy() + quantile_forecast_conv = outputs.full_predictions.numpy() + + # Compare outputs + point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv) + quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv) + + max_point_diff = point_forecast_diff.max() + mean_point_diff = point_forecast_diff.mean() + max_quantile_diff = quantile_forecast_diff.max() + mean_quantile_diff = quantile_forecast_diff.mean() + + print("\nOutput comparison:") + print(f"Point forecast - Max difference: {max_point_diff:.6f}") + print(f"Point forecast - Mean difference: {mean_point_diff:.6f}") + print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}") + print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}") + + # Define acceptable thresholds + POINT_THRESHOLD = 1e-5 + QUANTILE_THRESHOLD = 1e-5 + + if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD: + raise ValueError( + f"Output mismatch detected!\n" + f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n" + f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})" + ) + + print("\n✓ All outputs match within acceptable tolerance!") + + # Optional: Print shapes for verification + print("\nOutput shapes:") + print(f"Original point forecast: {point_forecast_orig.shape}") + print(f"Converted point forecast: {point_forecast_conv.shape}") + print(f"Original quantile forecast: {quantile_forecast_orig.shape}") + print(f"Converted quantile forecast: {quantile_forecast_conv.shape}") + + def main(): parser = argparse.ArgumentParser() parser.add_argument( From c3fbff27120b8c3c6a4d19a9095ba2fca994aec1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:21:47 +0100 Subject: [PATCH 120/242] remove dropout_rate --- tests/models/timesfm/test_modeling_timesfm.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index ec7550b043f3..a6cdaf17846e 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -56,7 +56,6 @@ def __init__( intermediate_size: int = 32, head_dim: int = 2, num_heads: int = 2, - dropout_rate: float = 0.1, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], @@ -78,7 +77,6 @@ def __init__( self.head_dim = head_dim self.num_hidden_layers = num_layers self.num_attention_heads = num_heads - self.dropout_rate = dropout_rate self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps self.use_positional_embedding = use_positional_embedding @@ -106,7 +104,6 @@ def get_config(self): head_dim=self.head_dim, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, - dropout_rate=self.dropout_rate, tolerance=self.tolerance, rms_norm_eps=self.rms_norm_eps, use_positional_embedding=self.use_positional_embedding, @@ -126,18 +123,10 @@ def prepare_config_and_inputs(self): config = self.get_config() - return ( - config, - forecast_input, - frequency_input, - ) + return (config, forecast_input, frequency_input) def prepare_config_and_inputs_for_common(self): - ( - config, - forecast_input, - frequency_input, - ) = self.prepare_config_and_inputs() + (config, forecast_input, frequency_input) = self.prepare_config_and_inputs() inputs_dict = { "inputs": forecast_input, From 9e6750cc97426d8e3033557e6062009d40c4b4e4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:53:09 +0100 Subject: [PATCH 121/242] works with torch.Tensor inputs --- .../convert_timesfm_orignal_to_pytorch.py | 10 +++- .../models/timesfm/modeling_timesfm.py | 56 +++++++++++-------- tests/models/timesfm/test_modeling_timesfm.py | 7 ++- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py index ac638e9efe66..9d1fd8afe7d0 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py @@ -164,9 +164,13 @@ def check_outputs(model_path): freq=frequency_input, ) + # Convert inputs to sequence of tensors + forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32) for ts in forecast_input] + frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long) + # Get predictions from converted model with torch.no_grad(): - outputs = converted_model(inputs=forecast_input, freq=frequency_input, return_dict=True) + outputs = converted_model(inputs=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) point_forecast_conv = outputs.mean_predictions.numpy() quantile_forecast_conv = outputs.full_predictions.numpy() @@ -213,7 +217,9 @@ def main(): required=True, help="Location to write HF model and tokenizer", ) - parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + parser.add_argument( + "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." + ) args = parser.parse_args() write_model( model_path=args.output_dir, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f2134e90e463..7b664ffd0527 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -24,7 +24,7 @@ import logging import math from dataclasses import dataclass -from typing import Any, List, Sequence, Tuple +from typing import List, Sequence, Tuple import numpy as np import torch @@ -452,10 +452,13 @@ def timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Ten def timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: - """Calculates the moving average using NumPy's convolution function.""" + """Calculates the moving average using PyTorch's convolution function.""" # Pad with zeros to handle initial window positions - arr_padded = np.pad(arr, (window_size - 1, 0), "constant") - smoothed_arr = np.convolve(arr_padded, np.ones(window_size), "valid") / window_size + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + # Create a convolution kernel + kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + # Apply convolution to calculate the moving average + smoothed_arr = F.conv1d(arr_padded.unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0)).squeeze() return [smoothed_arr, arr - smoothed_arr] @@ -700,14 +703,16 @@ def __init__(self, config: TimesFMConfig): # Initialize weights and apply final processing self.post_init() - def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[np.array, np.array, int]: + def _preprocess( + self, inputs: Sequence[torch.Tensor], freq: Sequence[int] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Formats and pads raw inputs to feed into the model. This function both pads each time series to match the context length, and pads the inputs to meet the SPMD shape requirement. Args: - inputs: A list of 1d Tensors. Each JTensor is the context time series of + inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. freq: list of frequencies @@ -722,11 +727,11 @@ def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[ for i, ts in enumerate(inputs): input_len = ts.shape[0] - padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + padding = torch.zeros(input_len + self.horizon_len, dtype=torch.float32) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0) - padding = np.concatenate([np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) + ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32), padding], dim=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] @@ -736,9 +741,9 @@ def _preprocess(self, inputs: Sequence[np.array], freq: Sequence[int]) -> tuple[ inp_freq.append(freq[i]) return ( - np.stack(input_ts, axis=0), - np.stack(input_padding, axis=0), - np.array(inp_freq).astype(np.int32).reshape(-1, 1), + torch.stack(input_ts, dim=0), + torch.stack(input_padding, dim=0), + torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1), ) def _postprocess_output( @@ -857,8 +862,8 @@ def decode( def forward( self, - inputs: Sequence[Any], - freq: Sequence[int] | None = None, + inputs: Sequence[torch.Tensor], + freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, forecast_context_len: int | None = None, return_forecast_on_context: bool = False, @@ -899,8 +904,13 @@ def forward( fcontext_len = self.context_len else: fcontext_len = forecast_context_len - inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] - inp_min = np.min([np.min(ts) for ts in inputs]) + + # Get device from first input tensor + device = inputs[0].device + + # Truncate inputs to forecast_context_len + inputs = [ts[-fcontext_len:] for ts in inputs] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: new_inputs = [] @@ -919,13 +929,15 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - input_ts_in = torch.from_numpy(np.array(input_ts, dtype=np.float32)) - input_padding_in = torch.from_numpy(np.array(input_padding, dtype=np.float32)) - inp_freq_in = torch.from_numpy(np.array(inp_freq, dtype=np.int32)).long() + # Move tensors to the same device as input + input_ts = input_ts.to(device) + input_padding = input_padding.to(device) + inp_freq = inp_freq.to(device) + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( - input_ts=input_ts_in, - paddings=input_padding_in, - freq=inp_freq_in, + input_ts=input_ts, + paddings=input_padding, + freq=inp_freq, horizon_len=self.horizon_len, return_forecast_on_context=return_forecast_on_context, output_attentions=output_attentions, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index a6cdaf17846e..7fd8f01ff247 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -18,6 +18,7 @@ from typing import List import numpy as np +import torch from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import ( @@ -115,9 +116,9 @@ def get_pipeline_config(self): def prepare_config_and_inputs(self): forecast_input = [ - np.sin(np.linspace(0, 20, 100)), - np.sin(np.linspace(0, 20, 200)), - np.sin(np.linspace(0, 20, 400)), + torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32), + torch.tensor(np.sin(np.linspace(0, 20, 200)), dtype=torch.float32), + torch.tensor(np.sin(np.linspace(0, 20, 400)), dtype=torch.float32), ] frequency_input = [0, 1, 2] From b437e87b8850de90c7f49c42287cdb2e877cff17 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 09:54:44 +0100 Subject: [PATCH 122/242] rename script --- ...m_orignal_to_pytorch.py => convert_timesfm_orignal_to_hf.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename src/transformers/models/timesfm/{convert_timesfm_orignal_to_pytorch.py => convert_timesfm_orignal_to_hf.py} (99%) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py similarity index 99% rename from src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py rename to src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 9d1fd8afe7d0..eeed750c337b 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -13,7 +13,7 @@ Sample usage: ``` -python src/transformers/models/timesfm/convert_timesfm_orignal_to_pytorch.py \ +python src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py \ --output_dir /output/path ``` """ From 9f0f086a345901b9f126b0729b4f300b7435c0e4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:04:14 +0100 Subject: [PATCH 123/242] fix docstrings --- .../models/timesfm/modeling_timesfm.py | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 7b664ffd0527..74779f6a1587 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -185,7 +185,7 @@ def forward(self, seq_length=None, position=None): class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM.""" + """Implements the attention used in TimesFM. One key diffrence is that there is _per_dim_scaling of the query.""" def __init__(self, config: TimesFMConfig): super().__init__() @@ -655,7 +655,8 @@ def forward( freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> TimesFMDecoderOutput: + return_dict: bool = True, + ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -674,13 +675,22 @@ def forward( else: all_hidden_states = None - return TimesFMDecoderOutput( - last_hidden_state=transformer_output.last_hidden_state, - hidden_states=all_hidden_states, - attentions=transformer_output.attentions if output_attentions else None, - loc=stats[0], - scale=stats[1], - ) + if return_dict: + return TimesFMDecoderOutput( + last_hidden_state=transformer_output.last_hidden_state, + hidden_states=all_hidden_states, + attentions=transformer_output.attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + else: + return ( + transformer_output.last_hidden_state, + all_hidden_states, + transformer_output.attentions, + stats[0], + stats[1], + ) class TimesFMModelForPrediction(TimesFMPreTrainedModel): @@ -778,7 +788,7 @@ def decode( return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, - ): + ) -> tuple[torch.Tensor, ...]: """Auto-regressive decoding without caching. Args: @@ -799,6 +809,9 @@ def decode( B x H' x (1 + # quantiles). In particular, if return_forecast_on_context is True, H' is H plus the forecastable context length, i.e. context_len - (first) patch_len. + + Raises: + ValueError: If the paddings do not match the input + horizon_len. """ final_out = input_ts context_len = final_out.shape[1] @@ -871,7 +884,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, - ) -> TimesFMOutputForPrediction: + ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]: """Forecasts on a list of time series. Args: @@ -887,15 +900,15 @@ def forward( when available, i.e. after the first input patch. truncate_negative: truncate to only non-negative values if all the contexts have non-negative values. + output_attentions: Whether to return the attentions. + output_hidden_states: Whether to return the hidden states. + return_dict: Whether to return a TimesFMOutputForPrediction object. Returns: - A tuple for Tensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size + A TimesFMOutputForPrediction object containing: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. """ if return_dict is None: return_dict = self.config.use_return_dict From f9e5db8df81922c3b49310a6a629fa62cc8899a1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:27:16 +0100 Subject: [PATCH 124/242] fix freq when window_size is given --- .../models/timesfm/modeling_timesfm.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 74779f6a1587..b3de01cd3682 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -133,11 +133,11 @@ class TimesFMPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. Attributes: + embedding_dims: Dimension of the embedding to be generated. min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. + the added signal. Defaults to 1. max_timescale: End of the geometric index. Determines the frequency of the - added signal. - embedding_dims: Dimension of the embedding to be generated. + added signal. Defaults to 10_000. """ def __init__( @@ -889,10 +889,9 @@ def forward( Args: inputs: list of time series forecast contexts. Each context time series - should be in a format convertible to Tensor. - freq: frequency of each context time series. 0 for high frequency - (default), 1 for medium, and 2 for low. Notice this is different from - the `freq` required by `forecast_on_df`. + should be a torch Tensor of potentially different context lengths. + freq: frequency of each context time series in the inputs. 0 for high frequency + (default), 1 for medium, and 2 for low. window_size: window size of trend + residual decomposition. If None then we do not do decomposition. forecast_context_len: optional max context length. @@ -927,9 +926,15 @@ def forward( if window_size is not None: new_inputs = [] - for ts in inputs: + if freq is not None: + new_freqs = [] + for i, ts in enumerate(inputs): new_inputs.extend(timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) inputs = new_inputs + if freq is not None: + freq = new_freqs if freq is None: logging.info("No frequency provided via `freq`. Default to high (0).") From c8703ff8b958def15125103440a50a4d3774a8bb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:48:17 +0100 Subject: [PATCH 125/242] add loss --- .../models/timesfm/modeling_timesfm.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b3de01cd3682..b93810265e3d 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -26,7 +26,6 @@ from dataclasses import dataclass from typing import List, Sequence, Tuple -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -38,14 +37,15 @@ @dataclass class TimesFMDecoderOutput(BaseModelOutput): - loc: np.ndarray | None = None - scale: np.ndarray | None = None + loc: torch.Tensor | None = None + scale: torch.Tensor | None = None @dataclass class TimesFMOutputForPrediction(BaseModelOutput): - mean_predictions: np.ndarray | None = None - full_predictions: np.ndarray | None = None + mean_predictions: torch.Tensor | None = None + full_predictions: torch.Tensor | None = None + loss: float | None = None class TimesFMTransformerMLP(nn.Module): @@ -873,11 +873,21 @@ def decode( decoder_output.hidden_states, ) + @staticmethod + def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor, quantiles: List[float]) -> torch.Tensor: + losses = [] + for q in quantiles: + errors = targets - predictions + loss = torch.max((q - 1) * errors, q * errors) + losses.append(loss.mean()) + return torch.stack(losses).mean() + def forward( self, inputs: Sequence[torch.Tensor], freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, + future_target: torch.Tensor | None = None, forecast_context_len: int | None = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, @@ -894,6 +904,7 @@ def forward( (default), 1 for medium, and 2 for low. window_size: window size of trend + residual decomposition. If None then we do not do decomposition. + future_target: optional future target time series to be used for loss computation. forecast_context_len: optional max context length. return_forecast_on_context: True to return the forecast on the context when available, i.e. after the first input patch. @@ -908,6 +919,7 @@ def forward( - the mean forecast of size (# inputs, # forecast horizon), - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). + - loss: the mean squared error loss + quantile loss if future_target is provided. """ if return_dict is None: return_dict = self.config.use_return_dict @@ -927,7 +939,7 @@ def forward( if window_size is not None: new_inputs = [] if freq is not None: - new_freqs = [] + new_freqs = [] for i, ts in enumerate(inputs): new_inputs.extend(timesfm_moving_average(ts, window_size)) if freq is not None: @@ -969,6 +981,12 @@ def forward( mean_outputs = torch.maximum(mean_outputs, 0.0) full_outputs = torch.maximum(full_outputs, 0.0) + loss = None + if future_target is not None: + mse_loss = torch.nn.functional.mse_loss(mean_outputs, future_target) + quantile_loss = self._quantile_loss(full_outputs, future_target, self.config.quantiles) + loss = mse_loss + quantile_loss + if return_dict: return TimesFMOutputForPrediction( last_hidden_state=last_hidden_state, @@ -976,6 +994,7 @@ def forward( hidden_states=all_hidden_states if output_hidden_states else None, mean_predictions=mean_outputs, full_predictions=full_outputs, + loss=loss, ) else: return_tuple = [last_hidden_state] @@ -983,5 +1002,5 @@ def forward( return_tuple.append(all_hidden_states) if output_attentions: return_tuple.append(all_attentions) - return_tuple += [mean_outputs, full_outputs] + return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) From 8f6c2e1a7e5f4f5da2bb41d2a7bf473fe634e7bb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 11:00:43 +0100 Subject: [PATCH 126/242] fix _quantile_loss --- .../models/timesfm/modeling_timesfm.py | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b93810265e3d..e20dfa879599 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -45,7 +45,7 @@ class TimesFMDecoderOutput(BaseModelOutput): class TimesFMOutputForPrediction(BaseModelOutput): mean_predictions: torch.Tensor | None = None full_predictions: torch.Tensor | None = None - loss: float | None = None + loss: torch.Tensor | float | None = None class TimesFMTransformerMLP(nn.Module): @@ -74,12 +74,7 @@ def forward(self, x, paddings=None): class TimesFMResidualBlock(nn.Module): """TimesFM residual block.""" - def __init__( - self, - input_dims, - hidden_dims, - output_dims, - ): + def __init__(self, input_dims, hidden_dims, output_dims): super().__init__() self.input_dims = input_dims self.hidden_dims = hidden_dims @@ -572,7 +567,7 @@ def __init__(self, config: TimesFMConfig): self.input_ff_layer = TimesFMResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, - hidden_dims=config.model_dim, + hidden_dims=config.intermediate_size, ) self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) self.stacked_transformer = TimesFMStackedDecoder(config=config) @@ -605,15 +600,8 @@ def _forward_transform( return outputs, (mu, sigma) def _preprocess_input( - self, - input_ts: torch.Tensor, - input_padding: torch.Tensor, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - tuple[torch.Tensor, torch.Tensor] | None, - torch.Tensor, - ]: + self, input_ts: torch.Tensor, input_padding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Preprocess input for stacked transformer.""" # Reshape into patches (using view for efficiency) @@ -707,7 +695,7 @@ def __init__(self, config: TimesFMConfig): self.horizon_ff_layer = TimesFMResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), - hidden_dims=config.model_dim, + hidden_dims=config.intermediate_size, ) # Initialize weights and apply final processing @@ -757,9 +745,7 @@ def _preprocess( ) def _postprocess_output( - self, - model_output: torch.Tensor, - stats: tuple[torch.Tensor, torch.Tensor], + self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor] ) -> torch.Tensor: """Postprocess output of stacked transformer.""" @@ -873,10 +859,9 @@ def decode( decoder_output.hidden_states, ) - @staticmethod - def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor, quantiles: List[float]) -> torch.Tensor: + def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for q in quantiles: + for q in self.config.quantiles: errors = targets - predictions loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) @@ -972,6 +957,7 @@ def forward( return_forecast_on_context=return_forecast_on_context, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + max_len=fcontext_len, ) if window_size is not None: @@ -983,8 +969,8 @@ def forward( loss = None if future_target is not None: - mse_loss = torch.nn.functional.mse_loss(mean_outputs, future_target) - quantile_loss = self._quantile_loss(full_outputs, future_target, self.config.quantiles) + mse_loss = F.mse_loss(mean_outputs, future_target) + quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_target) loss = mse_loss + quantile_loss if return_dict: From b3198738a82df93cd79c3b21938c639c4bd2b601 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 11:06:53 +0100 Subject: [PATCH 127/242] formatting --- docs/source/en/model_doc/timesfm.md | 5 +---- src/transformers/models/timesfm/modeling_timesfm.py | 7 ------- tests/models/timesfm/test_modeling_timesfm.py | 11 +---------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 76cf1f8afef2..4e2ee1ae0c61 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -25,11 +25,8 @@ The abstract from the paper is the following: *Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.* -Tips: - - -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +This model was contributed by [kashif](https://huggingface.co/kashif). The original code can be found [here](https://github.com/google-research/timesfm). diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e20dfa879599..ab8894730ca8 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -14,13 +14,6 @@ # limitations under the License. """PyTorch TimesFM model.""" - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### - import logging import math from dataclasses import dataclass diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 7fd8f01ff247..da534f4092ea 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -21,25 +21,16 @@ import torch from transformers import TimesFMConfig, is_torch_available -from transformers.testing_utils import ( - require_torch, - torch_device, -) +from transformers.testing_utils import require_torch, torch_device from transformers.utils import is_torch_fx_available -# from ...generation.test_utils import GenerationTesterMixin -# define our own GenerationTesters from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin -# from ...test_pipeline_mixin import PipelineTesterMixin - - if is_torch_fx_available(): pass - if is_torch_available(): from transformers import TimesFMModelForPrediction From 3bd0827e4d1363898920bf4d13329df633194f80 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 11:17:56 +0100 Subject: [PATCH 128/242] fix isort --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a38dc203485d..9ac362efb0a4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3742,8 +3742,8 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMModelForPrediction", "TimesFMDecoder", + "TimesFMModelForPrediction", "TimesFMPreTrainedModel", ] ) From 0d4325ee7740f4a2587a7162536baeeece4b9fee Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 12:13:47 +0100 Subject: [PATCH 129/242] add weight init --- .../models/timesfm/modeling_timesfm.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index ab8894730ca8..b64bde1288da 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -546,6 +546,51 @@ def _init_weights(self, module): elif isinstance(module, TimesFMRMSNorm): nn.init.zeros_(module.weight) + elif isinstance(module, TimesFMTransformerMLP): + # Initialize gate projection + module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.gate_proj.bias is not None: + nn.init.zeros_(module.gate_proj.bias) + + # Initialize down projection + module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.down_proj.bias is not None: + nn.init.zeros_(module.down_proj.bias) + + # Initialize layer norm + nn.init.ones_(module.layer_norm.weight) + nn.init.zeros_(module.layer_norm.bias) + + elif isinstance(module, TimesFMAttention): + # Initialize qkv projection + module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.qkv_proj.bias is not None: + nn.init.zeros_(module.qkv_proj.bias) + + # Initialize output projection + module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.o_proj.bias is not None: + nn.init.zeros_(module.o_proj.bias) + + # Initialize scaling parameter + nn.init.ones_(module.scaling) + + elif isinstance(module, TimesFMResidualBlock): + # Initialize hidden layer + module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.hidden_layer[0].bias is not None: + nn.init.zeros_(module.hidden_layer[0].bias) + + # Initialize output layer + module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.output_layer.bias is not None: + nn.init.zeros_(module.output_layer.bias) + + # Initialize residual layer + module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + if module.residual_layer.bias is not None: + nn.init.zeros_(module.residual_layer.bias) + elif isinstance(module, TimesFMPositionalEmbedding): pass From 4212ef81f3fbe012afa02be18c561c5c88f0d369 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 8 Dec 2024 13:44:02 +0100 Subject: [PATCH 130/242] add support for sdpa and flash_attention_2 --- .../models/timesfm/configuration_timesfm.py | 4 + .../models/timesfm/modeling_timesfm.py | 194 +++++++++++++++++- 2 files changed, 187 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index eb135c03c968..012315882957 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -62,6 +62,8 @@ class TimesFMConfig(PretrainedConfig): The quantiles to predict. pad_val (`float`, *optional*, defaults to 1123581321.0): The value used to pad the predictions. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention scores. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. initializer_factor (`float`, *optional*, defaults to 1.0): @@ -93,6 +95,7 @@ def __init__( rms_norm_eps: float = 1e-6, quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, + attention_dropout: float = 0.0, use_positional_embedding: bool = True, initializer_factor: float = 1.0, **kwargs, @@ -110,6 +113,7 @@ def __init__( self.num_heads = num_heads self.tolerance = tolerance self.rms_norm_eps = rms_norm_eps + self.attention_dropout = attention_dropout self.use_positional_embedding = use_positional_embedding self.initializer_factor = initializer_factor diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b64bde1288da..38c8dc9bb6cc 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -25,9 +25,14 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 from .configuration_timesfm import TimesFMConfig +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + @dataclass class TimesFMDecoderOutput(BaseModelOutput): loc: torch.Tensor | None = None @@ -188,9 +193,7 @@ def __init__(self, config: TimesFMConfig): self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = nn.Parameter( - torch.empty((self.head_dim,), dtype=torch.float32), - ) + self.scaling = nn.Parameter(torch.empty((self.head_dim,))) self.qkv_proj = nn.Linear( self.hidden_size, @@ -198,12 +201,8 @@ def __init__(self, config: TimesFMConfig): ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) - def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: - # [batch_size, n_local_heads, input_len, head_dim] - r_softplus_0 = 1.442695041 - softplus_func = torch.nn.Softplus() - scale = r_softplus_0 / math.sqrt(self.head_dim) - scale = scale * softplus_func(self.scaling) + def _scale_query(self, query: torch.Tensor) -> torch.Tensor: + scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim)) return query * scale[None, None, None, :] def forward( @@ -225,7 +224,7 @@ def forward( xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xq = self._per_dim_scaling(xq) + xq = self._scale_query(xq) # Write new kv cache. # [batch_size, input_len, n_local_kv_heads, head_dim] @@ -272,12 +271,182 @@ def forward( return output, scores +class TimesFMFlashAttention2(TimesFMAttention): + """TimesFM attention implementation using Flash Attention 2.""" + + def __init__(self, config: TimesFMConfig): + super().__init__(config) + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + output_attentions=output_attentions, + ) + + batch_size, seq_length, _ = hidden_states.shape + + # Project to q, k, v + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + + # Scale query using the model's learned scaling + xq = self._scale_query(xq) + + # Handle KV cache + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + key = k_cache + value = v_cache + else: + key = xk + value = xv + + # Handle grouped attention + if self.num_queries_per_kv > 1: + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # Transpose for attention + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Run flash attention + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + seq_length, + dropout_p=self.attention_dropout if self.training else 0.0, + softmax_scale=1, # Set to 1.0 to disable default scaling + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + # Reshape output + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +class TimesFMSdpaAttention(TimesFMAttention): + """TimesFM attention implementation using torch.nn.functional.scaled_dot_product_attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + kv_write_indices: torch.Tensor | None = None, + kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + output_attentions=output_attentions, + ) + + hidden_states_shape = hidden_states.shape + batch_size, seq_length, _ = hidden_states_shape + + # Project to queries, keys, values + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape: [batch_size, seq_length, num_heads * head_dim] -> [batch_size, seq_length, num_heads, head_dim] + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + + # Scale query exactly as in original + xq = self._scale_query(xq) + + # Handle KV cache + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + key = k_cache + value = v_cache + else: + key = xk + value = xv + + # Handle grouped attention + if self.num_queries_per_kv > 1: + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # Transpose for attention: [batch_size, num_heads, seq_length, head_dim] + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Make inputs contiguous + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # Run scaled dot-product attention + # Note: attention_mask should already be in the correct format from TimesFMStackedDecoder + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=False, # We use the provided attention mask + scale=1, # We already scaled the query + ) + + # Reshape output: [batch_size, seq_length, hidden_size] + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +TIMESFM_ATTENTION_CLASSES = { + "eager": TimesFMAttention, + "flash_attention_2": TimesFMFlashAttention2, + "sdpa": TimesFMSdpaAttention, +} + + class TimesFMDecoderLayer(nn.Module): """Transformer layer.""" def __init__(self, config: TimesFMConfig): super().__init__() - self.self_attn = TimesFMAttention(config) + + if config._attn_implementation not in TIMESFM_ATTENTION_CLASSES: + raise ValueError(f"Unknown attention implementation: {config._attn_implementation}") + attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] + + self.self_attn = attention_class(config) self.mlp = TimesFMTransformerMLP(config.model_dim, config.intermediate_size) self.input_layernorm = TimesFMRMSNorm(config.model_dim, eps=config.rms_norm_eps) @@ -529,6 +698,7 @@ class TimesFMPreTrainedModel(PreTrainedModel): config_class = TimesFMConfig base_model_prefix = "timesfm" main_input_name = "inputs" + _supports_sdpa = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -720,6 +890,8 @@ def forward( class TimesFMModelForPrediction(TimesFMPreTrainedModel): + """TimesFM model for quantile and mean prediction.""" + def __init__(self, config: TimesFMConfig): super().__init__(config) From 9739e4bf78f91bd69219a1a2d14851c9163e5ffa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 8 Dec 2024 14:20:51 +0100 Subject: [PATCH 131/242] fixes for flash_attention --- .../models/timesfm/modeling_timesfm.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 38c8dc9bb6cc..e1da8515b6b9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -330,6 +330,14 @@ def forward( key = key.transpose(1, 2) value = value.transpose(1, 2) + # Convert attention mask to proper format for Flash Attention + if attention_mask is not None: + # Convert from [batch_size, 1, seq_length, seq_length] to [batch_size, seq_length] + # by checking which positions are not allowed to attend to any other position + attention_mask = attention_mask.squeeze(1) # [batch_size, seq_length, seq_length] + attention_mask = ~attention_mask.all(dim=-1) # [batch_size, seq_length] + attention_mask = attention_mask.to(query.dtype) + # Run flash attention attn_output = _flash_attention_forward( query, @@ -337,9 +345,10 @@ def forward( value, attention_mask, seq_length, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout=self.attention_dropout if self.training else 0.0, softmax_scale=1, # Set to 1.0 to disable default scaling use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=False, ) # Reshape output @@ -699,6 +708,7 @@ class TimesFMPreTrainedModel(PreTrainedModel): base_model_prefix = "timesfm" main_input_name = "inputs" _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -938,8 +948,8 @@ def _preprocess( padding = torch.zeros(input_len + self.horizon_len, dtype=torch.float32) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32), padding], dim=0) + ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32, device=ts.device), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] From 33cee018fd9990c6999aec32a86353649d830fa2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 8 Dec 2024 14:21:26 +0100 Subject: [PATCH 132/242] formatting --- src/transformers/models/timesfm/modeling_timesfm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e1da8515b6b9..3f60f435b482 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -949,7 +949,9 @@ def _preprocess( if input_len < self.context_len: num_front_pad = self.context_len - input_len ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0) + padding = torch.cat( + [torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0 + ) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] From bce6405deadf5c9f0b3d1a938407620435e6a899 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 10:40:04 +0100 Subject: [PATCH 133/242] remove flash_attention --- .../timesfm/convert_timesfm_orignal_to_hf.py | 35 ++++-- .../models/timesfm/modeling_timesfm.py | 102 +----------------- 2 files changed, 27 insertions(+), 110 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index eeed750c337b..7284f9c24c17 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -26,7 +26,7 @@ def write_model(model_path, safe_serialization=True): tfm = timesfm.TimesFm( hparams=timesfm.TimesFmHparams( - backend="cpu", + backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, ), @@ -139,7 +139,7 @@ def check_outputs(model_path): # Load original model tfm = timesfm.TimesFm( hparams=timesfm.TimesFmHparams( - backend="cpu", + backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, ), @@ -147,7 +147,11 @@ def check_outputs(model_path): ) # Load converted model - converted_model = TimesFMModelForPrediction.from_pretrained(model_path) + converted_model = TimesFMModelForPrediction.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa", + ).to("cuda" if torch.cuda.is_available() else "cpu") converted_model.eval() # Set to evaluation mode # Create test inputs @@ -165,14 +169,19 @@ def check_outputs(model_path): ) # Convert inputs to sequence of tensors - forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32) for ts in forecast_input] - frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long) + forecast_input_tensor = [ + torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu") + for ts in forecast_input + ] + frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to( + "cuda" if torch.cuda.is_available() else "cpu" + ) # Get predictions from converted model with torch.no_grad(): outputs = converted_model(inputs=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) - point_forecast_conv = outputs.mean_predictions.numpy() - quantile_forecast_conv = outputs.full_predictions.numpy() + point_forecast_conv = outputs.mean_predictions.float().cpu().numpy() + quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy() # Compare outputs point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv) @@ -221,11 +230,15 @@ def main(): "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." ) args = parser.parse_args() - write_model( - model_path=args.output_dir, - safe_serialization=args.safe_serialization, - ) + # if the saved model file exists, skip the conversion + if os.path.exists(os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "model.bin")): + print(f"Model already exists in {args.output_dir}, skipping conversion.") + else: + write_model( + model_path=args.output_dir, + safe_serialization=args.safe_serialization, + ) check_outputs(args.output_dir) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 3f60f435b482..f388e456bc8c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -25,14 +25,9 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 from .configuration_timesfm import TimesFMConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - @dataclass class TimesFMDecoderOutput(BaseModelOutput): loc: torch.Tensor | None = None @@ -271,93 +266,6 @@ def forward( return output, scores -class TimesFMFlashAttention2(TimesFMAttention): - """TimesFM attention implementation using Flash Attention 2.""" - - def __init__(self, config: TimesFMConfig): - super().__init__(config) - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - if output_attentions: - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - output_attentions=output_attentions, - ) - - batch_size, seq_length, _ = hidden_states.shape - - # Project to q, k, v - qkv = self.qkv_proj(hidden_states) - xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Reshape - xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim) - xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) - xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim) - - # Scale query using the model's learned scaling - xq = self._scale_query(xq) - - # Handle KV cache - if kv_cache is not None and kv_write_indices is not None: - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - key = k_cache - value = v_cache - else: - key = xk - value = xv - - # Handle grouped attention - if self.num_queries_per_kv > 1: - key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) - value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) - - # Transpose for attention - query = xq.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Convert attention mask to proper format for Flash Attention - if attention_mask is not None: - # Convert from [batch_size, 1, seq_length, seq_length] to [batch_size, seq_length] - # by checking which positions are not allowed to attend to any other position - attention_mask = attention_mask.squeeze(1) # [batch_size, seq_length, seq_length] - attention_mask = ~attention_mask.all(dim=-1) # [batch_size, seq_length] - attention_mask = attention_mask.to(query.dtype) - - # Run flash attention - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - seq_length, - dropout=self.attention_dropout if self.training else 0.0, - softmax_scale=1, # Set to 1.0 to disable default scaling - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=False, - ) - - # Reshape output - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None - - class TimesFMSdpaAttention(TimesFMAttention): """TimesFM attention implementation using torch.nn.functional.scaled_dot_product_attention.""" @@ -440,7 +348,6 @@ def forward( TIMESFM_ATTENTION_CLASSES = { "eager": TimesFMAttention, - "flash_attention_2": TimesFMFlashAttention2, "sdpa": TimesFMSdpaAttention, } @@ -708,7 +615,6 @@ class TimesFMPreTrainedModel(PreTrainedModel): base_model_prefix = "timesfm" main_input_name = "inputs" _supports_sdpa = True - _supports_flash_attn_2 = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -945,13 +851,11 @@ def _preprocess( for i, ts in enumerate(inputs): input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=torch.float32) + padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) if input_len < self.context_len: num_front_pad = self.context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=torch.float32, device=ts.device), ts], dim=0) - padding = torch.cat( - [torch.ones(num_front_pad, dtype=torch.float32, device=padding.device), padding], dim=0 - ) + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) elif input_len > self.context_len: ts = ts[-self.context_len :] padding = padding[-(self.context_len + self.horizon_len) :] From fb33f35ffdeb5d3401aad82f74a3fb4663e85b95 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 11:03:25 +0100 Subject: [PATCH 134/242] fix tests --- src/transformers/models/timesfm/modeling_timesfm.py | 8 ++++++-- tests/models/timesfm/test_modeling_timesfm.py | 12 +++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f388e456bc8c..35cfd0d537b7 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -177,6 +177,8 @@ class TimesFMAttention(nn.Module): def __init__(self, config: TimesFMConfig): super().__init__() + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_heads self.num_kv_heads = config.num_heads @@ -260,7 +262,7 @@ def forward( output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) output = self.o_proj(output) - if output_attentions: + if not output_attentions: scores = None return output, scores @@ -373,6 +375,7 @@ def forward( paddings: torch.Tensor, kv_write_indices: torch.Tensor | None = None, kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + output_attentions: bool = False, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -382,6 +385,7 @@ def forward( attention_mask=attention_mask, kv_write_indices=kv_write_indices, kv_cache=kv_cache, + output_attentions=output_attentions, ) hidden_states = residual + hidden_states @@ -423,6 +427,7 @@ def forward( paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, + output_attentions=output_attentions, ) if output_attentions: all_attentions.append(scores) @@ -727,7 +732,6 @@ def _preprocess_input( self, input_ts: torch.Tensor, input_padding: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Preprocess input for stacked transformer.""" - # Reshape into patches (using view for efficiency) bsize = input_ts.shape[0] patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index da534f4092ea..ea734402c1e5 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -107,15 +107,13 @@ def get_pipeline_config(self): def prepare_config_and_inputs(self): forecast_input = [ - torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32), - torch.tensor(np.sin(np.linspace(0, 20, 200)), dtype=torch.float32), - torch.tensor(np.sin(np.linspace(0, 20, 400)), dtype=torch.float32), + torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device), + torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device), + torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device), ] - frequency_input = [0, 1, 2] + frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device) - config = self.get_config() - - return (config, forecast_input, frequency_input) + return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input) def prepare_config_and_inputs_for_common(self): (config, forecast_input, frequency_input) = self.prepare_config_and_inputs() From b41c3686a7dd7087b6a23685ec811cea41216dfa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 11:06:31 +0100 Subject: [PATCH 135/242] fix file name --- .../models/timesfm/convert_timesfm_orignal_to_hf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 7284f9c24c17..f791fb0236aa 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -3,9 +3,10 @@ import shutil import numpy as np -import timesfm + import torch +import timesfm from transformers import TimesFMConfig, TimesFMModelForPrediction @@ -232,7 +233,9 @@ def main(): args = parser.parse_args() # if the saved model file exists, skip the conversion - if os.path.exists(os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "model.bin")): + if os.path.exists( + os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin") + ): print(f"Model already exists in {args.output_dir}, skipping conversion.") else: write_model( From 9aad1013f036bd8f24cfcd65170e479d1686d4ae Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 12:56:22 +0100 Subject: [PATCH 136/242] fix quantile loss --- .../models/timesfm/modeling_timesfm.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 35cfd0d537b7..3bbba4885c5e 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -851,7 +851,7 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - input_ts, input_padding, inp_freq = [], [], [] + input_ts, input_padding = [], [] for i, ts in enumerate(inputs): input_len = ts.shape[0] @@ -866,12 +866,11 @@ def _preprocess( input_ts.append(ts) input_padding.append(padding) - inp_freq.append(freq[i]) return ( torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0), - torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1), + torch.tensor(freq, dtype=torch.int32, device=input_ts[0].device).reshape(-1, 1), ) def _postprocess_output( @@ -991,8 +990,8 @@ def decode( def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for q in self.config.quantiles: - errors = targets - predictions + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[:, :, i] loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) return torch.stack(losses).mean() @@ -1074,11 +1073,6 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) - # Move tensors to the same device as input - input_ts = input_ts.to(device) - input_padding = input_padding.to(device) - inp_freq = inp_freq.to(device) - mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( input_ts=input_ts, paddings=input_padding, From be8922fa97925051d8ba9f186fc513d48de92530 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Dec 2024 14:26:52 +0100 Subject: [PATCH 137/242] added initial TimesFMModelIntegrationTests --- .../models/timesfm/modeling_timesfm.py | 14 +++++--- tests/models/timesfm/test_modeling_timesfm.py | 36 +++++++++++++++++-- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 3bbba4885c5e..35cfd0d537b7 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -851,7 +851,7 @@ def _preprocess( - the number of padded examples for SPMD so that each core has the same number (a multiple of `batch_size`) of examples. """ - input_ts, input_padding = [], [] + input_ts, input_padding, inp_freq = [], [], [] for i, ts in enumerate(inputs): input_len = ts.shape[0] @@ -866,11 +866,12 @@ def _preprocess( input_ts.append(ts) input_padding.append(padding) + inp_freq.append(freq[i]) return ( torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0), - torch.tensor(freq, dtype=torch.int32, device=input_ts[0].device).reshape(-1, 1), + torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1), ) def _postprocess_output( @@ -990,8 +991,8 @@ def decode( def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for i, q in enumerate(self.config.quantiles): - errors = targets - predictions[:, :, i] + for q in self.config.quantiles: + errors = targets - predictions loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) return torch.stack(losses).mean() @@ -1073,6 +1074,11 @@ def forward( input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) + # Move tensors to the same device as input + input_ts = input_ts.to(device) + input_padding = input_padding.to(device) + inp_freq = inp_freq.to(device) + mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( input_ts=input_ts, paddings=input_padding, diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index ea734402c1e5..834748218f2b 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -20,8 +20,9 @@ import numpy as np import torch +from huggingface_hub import hf_hub_download from transformers import TimesFMConfig, is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_fx_available from ...test_configuration_common import ConfigTester @@ -32,7 +33,9 @@ pass if is_torch_available(): - from transformers import TimesFMModelForPrediction + from transformers import TimesFMDecoder, TimesFMModelForPrediction + +TOLERANCE = 1e-4 class TimesFMModelTester: @@ -46,7 +49,7 @@ def __init__( num_layers: int = 1, model_dim: int = 16, intermediate_size: int = 32, - head_dim: int = 2, + head_dim: int = 8, num_heads: int = 2, tolerance: float = 1e-6, rms_norm_eps: float = 1e-6, @@ -163,3 +166,30 @@ def test_model_main_input_name(self): # The main input is the name of the argument after `self` observed_main_input_name = list(model_signature.parameters.keys())[1] self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name) + + +@require_torch +@slow +class TimesFMModelIntegrationTests(unittest.TestCase): + @classmethod + def load_batch(cls, filename="train-batch.pt"): + file = hf_hub_download( + repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset" + ) + batch = torch.load(file, map_location=torch_device) + return batch + + def test_inference_no_head(self): + model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device) + batch = self.load_batch() + with torch.no_grad(): + inputs = batch["past_values"] + output = model(inputs=inputs).last_hidden_state + self.assertEqual( + output.shape, torch.Size([64, model.config.context_len // model.config.patch_len, model.config.model_dim]) + ) + + expected_slice = torch.tensor( + [[-4.0141, 3.3141, 1.9321], [-4.9121, 3.1443, 2.0836], [-5.1142, 2.7376, 2.1566]], device=torch_device + ) + self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE)) From c46864435a1e79961bba549bc081303d3be987f2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 10 Dec 2024 10:13:40 +0100 Subject: [PATCH 138/242] fix formatting --- tests/models/timesfm/test_modeling_timesfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 834748218f2b..46f758254bbd 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -19,8 +19,8 @@ import numpy as np import torch - from huggingface_hub import hf_hub_download + from transformers import TimesFMConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_fx_available @@ -33,7 +33,7 @@ pass if is_torch_available(): - from transformers import TimesFMDecoder, TimesFMModelForPrediction + from transformers import TimesFMModelForPrediction TOLERANCE = 1e-4 From 689d2a4514e6040d056c0c8908dbab8a308fc349 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 22 Dec 2024 14:47:57 +0100 Subject: [PATCH 139/242] fix import order --- .../models/timesfm/convert_timesfm_orignal_to_hf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index f791fb0236aa..f4256ba653c6 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -3,10 +3,9 @@ import shutil import numpy as np - +import timesfm import torch -import timesfm from transformers import TimesFMConfig, TimesFMModelForPrediction From abb1c0ae63019571a132831672043cedf354aadf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 22 Dec 2024 15:01:00 +0100 Subject: [PATCH 140/242] fix _quantile_loss --- src/transformers/models/timesfm/modeling_timesfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 35cfd0d537b7..c9ca29bf4b44 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -991,8 +991,8 @@ def decode( def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] - for q in self.config.quantiles: - errors = targets - predictions + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[..., i] loss = torch.max((q - 1) * errors, q * errors) losses.append(loss.mean()) return torch.stack(losses).mean() From 686c71bb9357ea2172fe84ac3818d6c2899d083d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 22 Dec 2024 15:20:47 +0100 Subject: [PATCH 141/242] add doc for SDPA --- docs/source/en/perf_infer_gpu_one.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 7f57a99c7d35..5534a43e20b2 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -314,6 +314,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFMModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) From 91c50a4aaf4c13a1028abbfa98af2ee0bcfee8ff Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 3 Jan 2025 23:46:44 +0100 Subject: [PATCH 142/242] use timesfm 2.0 --- .../timesfm/convert_timesfm_orignal_to_hf.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index f4256ba653c6..b48e1b8c7dc8 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -19,7 +19,7 @@ """ -def write_model(model_path, safe_serialization=True): +def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) @@ -29,8 +29,13 @@ def write_model(model_path, safe_serialization=True): backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + model_dims=1280, + use_positional_embedding=False, ), - checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) timesfm_config = TimesFMConfig( @@ -42,6 +47,7 @@ def write_model(model_path, safe_serialization=True): intermediate_size=tfm.hparams.model_dims, head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads, num_heads=tfm.hparams.num_heads, + use_positional_embedding=tfm.hparams.use_positional_embedding, ) timesfm_config.save_pretrained(tmp_model_path) timesfm_model = TimesFMModelForPrediction(timesfm_config) @@ -132,7 +138,7 @@ def write_model(model_path, safe_serialization=True): shutil.rmtree(tmp_model_path) -def check_outputs(model_path): +def check_outputs(model_path, huggingface_repo_id): """Compares outputs between original and converted models.""" print("\nChecking model outputs...") @@ -142,8 +148,13 @@ def check_outputs(model_path): backend="cuda" if torch.cuda.is_available() else "cpu", per_core_batch_size=32, horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + model_dims=1280, + use_positional_embedding=False, ), - checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch"), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) # Load converted model @@ -229,6 +240,12 @@ def main(): parser.add_argument( "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default="google/timesfm-2.0-500m-pytorch", + help="The Hugging Face repository ID to use for the model.", + ) args = parser.parse_args() # if the saved model file exists, skip the conversion @@ -240,8 +257,9 @@ def main(): write_model( model_path=args.output_dir, safe_serialization=args.safe_serialization, + huggingface_repo_id=args.huggingface_repo_id, ) - check_outputs(args.output_dir) + check_outputs(args.output_dir, args.huggingface_repo_id) if __name__ == "__main__": From cef8510317b3bfec7415a587f5028511ee72790c Mon Sep 17 00:00:00 2001 From: Rajat Sen Date: Fri, 3 Jan 2025 23:33:20 +0000 Subject: [PATCH 143/242] bug fix in timesfm decode function. --- src/transformers/models/timesfm/modeling_timesfm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c9ca29bf4b44..15a1d901797c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -900,7 +900,7 @@ def decode( freq: torch.LongTensor, horizon_len: int, output_patch_len: int | None = None, - max_len: int = 512, + max_len: int | None = None, return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, @@ -914,7 +914,8 @@ def decode( horizon_len: prediction length. output_patch_len: output length to be fetched from one step of auto-regressive decoding. - max_len: maximum training context length. + max_len: maximum training context length. If None, then we use the length + of the initial context as max length. return_forecast_on_context: whether to return the model forecast on the context except the first input patch. @@ -940,6 +941,9 @@ def decode( ) if output_patch_len is None: output_patch_len = self.config.horizon_len + if max_len is None: + max_len = context_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len for step_index in range(num_decode_patches): current_padding = paddings[:, 0 : final_out.shape[1]] @@ -961,7 +965,8 @@ def decode( # For the first decodings step, collect the model forecast on the # context except the unavailable first input batch forecast. new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1, new_full_ts.size(3)) + # We have to use reshape and not view for non-contiguous memory + new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3)) full_outputs.append(new_full_ts) From 7c7e56f8f4bb6e2ac432c117173293111df55844 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 6 Jan 2025 14:34:34 +0100 Subject: [PATCH 144/242] compare mean forecasts --- src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index b48e1b8c7dc8..5309caca77a8 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -153,6 +153,7 @@ def check_outputs(model_path, huggingface_repo_id): num_layers=50, model_dims=1280, use_positional_embedding=False, + point_forecast_mode="mean", ), checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) From 22bb7cfad053eb4cec84a36190c2645fcf65d59c Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Thu, 9 Jan 2025 17:06:59 -0800 Subject: [PATCH 145/242] refactor type hints, use CamelCase --- docs/source/en/model_doc/timesfm.md | 12 +- docs/source/en/perf_infer_gpu_one.md | 2 +- src/transformers/__init__.py | 16 +- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 4 +- src/transformers/models/timesfm/__init__.py | 12 +- .../models/timesfm/configuration_timesfm.py | 6 +- .../timesfm/convert_timesfm_orignal_to_hf.py | 8 +- .../models/timesfm/modeling_timesfm.py | 140 +++++++++--------- src/transformers/utils/dummy_pt_objects.py | 6 +- tests/models/timesfm/test_modeling_timesfm.py | 32 ++-- 11 files changed, 121 insertions(+), 121 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 4e2ee1ae0c61..3edf1aedbb0a 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -30,18 +30,18 @@ This model was contributed by [kashif](https://huggingface.co/kashif). The original code can be found [here](https://github.com/google-research/timesfm). -## TimesFMConfig +## TimesFmConfig -[[autodoc]] TimesFMConfig +[[autodoc]] TimesFmConfig -## TimesFMDecoder +## TimesFmDecoder -[[autodoc]] TimesFMDecoder +[[autodoc]] TimesFmDecoder - forward -## TimesFMModelForPrediction +## TimesFmModelForPrediction -[[autodoc]] TimesFMModelForPrediction +[[autodoc]] TimesFmModelForPrediction - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 5534a43e20b2..508d7fe3f579 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -314,7 +314,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) -* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFMModel) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9ac362efb0a4..55146c35e7fe 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -820,7 +820,7 @@ ], "models.textnet": ["TextNetConfig"], "models.time_series_transformer": ["TimeSeriesTransformerConfig"], - "models.timesfm": ["TimesFMConfig"], + "models.timesfm": ["TimesFmConfig"], "models.timesformer": ["TimesformerConfig"], "models.timm_backbone": ["TimmBackboneConfig"], "models.timm_wrapper": ["TimmWrapperConfig"], @@ -3742,9 +3742,9 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFMDecoder", - "TimesFMModelForPrediction", - "TimesFMPreTrainedModel", + "TimesFmDecoder", + "TimesFmModelForPrediction", + "TimesFmPreTrainedModel", ] ) _import_structure["models.timesformer"].extend( @@ -6019,7 +6019,7 @@ from .models.time_series_transformer import ( TimeSeriesTransformerConfig, ) - from .models.timesfm import TimesFMConfig + from .models.timesfm import TimesFmConfig from .models.timesformer import ( TimesformerConfig, ) @@ -8455,9 +8455,9 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFMDecoder, - TimesFMModelForPrediction, - TimesFMPreTrainedModel, + TimesFmDecoder, + TimesFmModelForPrediction, + TimesFmPreTrainedModel, ) from .models.timesformer import ( TimesformerForVideoClassification, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c5c8eab382a9..bf75c8b80e9e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -295,7 +295,7 @@ ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), - ("timesfm", "TimesFMConfig"), + ("timesfm", "TimesFmConfig"), ("timesformer", "TimesformerConfig"), ("timm_backbone", "TimmBackboneConfig"), ("timm_wrapper", "TimmWrapperConfig"), @@ -646,7 +646,7 @@ ("tapex", "TAPEX"), ("textnet", "TextNet"), ("time_series_transformer", "Time Series Transformer"), - ("timesfm", "TimesFM"), + ("timesfm", "TimesFm"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4ec04c736771..0c93ccc04ae9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -271,7 +271,7 @@ ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), - ("timesfm", "TimesFMModelForPrediction"), + ("timesfm", "TimesFmModelForPrediction"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), @@ -381,7 +381,7 @@ ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), - ("timesfm", "TimesFMModelForPrediction"), + ("timesfm", "TimesFmModelForPrediction"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 51028a860782..0441d2ce1eda 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -21,7 +21,7 @@ ) -_import_structure = {"configuration_timesfm": ["TimesFMConfig", "TimesFMOnnxConfig"]} +_import_structure = {"configuration_timesfm": ["TimesFmConfig", "TimesFmOnnxConfig"]} try: if not is_torch_available(): @@ -30,13 +30,13 @@ pass else: _import_structure["modeling_timesfm"] = [ - "TimesFMModelForPrediction", - "TimesFMDecoder", - "TimesFMPreTrainedModel", + "TimesFmModelForPrediction", + "TimesFmDecoder", + "TimesFmPreTrainedModel", ] if TYPE_CHECKING: - from .configuration_timesfm import TimesFMConfig, TimesFMOnnxConfig + from .configuration_timesfm import TimesFmConfig, TimesFmOnnxConfig try: if not is_torch_available(): @@ -44,7 +44,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_timesfm import TimesFMDecoder, TimesFMModelForPrediction, TimesFMPreTrainedModel + from .modeling_timesfm import TimesFmDecoder, TimesFmModelForPrediction, TimesFmPreTrainedModel else: import sys diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 012315882957..62e4920ca14a 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -24,9 +24,9 @@ logger = logging.get_logger(__name__) -class TimesFMConfig(PretrainedConfig): +class TimesFmConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`TimesFMModelForPrediction`] or a [`TFTimesFMDecoder`]. It is used to + This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmDecoder`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. @@ -123,7 +123,7 @@ def __init__( ) -class TimesFMOnnxConfig(OnnxSeq2SeqConfigWithPast): +class TimesFmOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = { diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 5309caca77a8..136efdd7d3e9 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -6,7 +6,7 @@ import timesfm import torch -from transformers import TimesFMConfig, TimesFMModelForPrediction +from transformers import TimesFmConfig, TimesFmModelForPrediction """ @@ -38,7 +38,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), ) - timesfm_config = TimesFMConfig( + timesfm_config = TimesFmConfig( patch_len=tfm.hparams.input_patch_len, context_len=tfm.hparams.context_len, horizon_len=tfm.hparams.horizon_len, @@ -50,7 +50,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google use_positional_embedding=tfm.hparams.use_positional_embedding, ) timesfm_config.save_pretrained(tmp_model_path) - timesfm_model = TimesFMModelForPrediction(timesfm_config) + timesfm_model = TimesFmModelForPrediction(timesfm_config) # copy the weights from the original model to the new model making original_model = tfm._model @@ -159,7 +159,7 @@ def check_outputs(model_path, huggingface_repo_id): ) # Load converted model - converted_model = TimesFMModelForPrediction.from_pretrained( + converted_model = TimesFmModelForPrediction.from_pretrained( model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa", diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 15a1d901797c..c4b3ddac9d7a 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -17,7 +17,7 @@ import logging import math from dataclasses import dataclass -from typing import List, Sequence, Tuple +from typing import List, Sequence, Tuple, Optional, Union import torch import torch.nn as nn @@ -25,23 +25,23 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from .configuration_timesfm import TimesFMConfig +from .configuration_timesfm import TimesFmConfig @dataclass -class TimesFMDecoderOutput(BaseModelOutput): - loc: torch.Tensor | None = None - scale: torch.Tensor | None = None +class TimesFmDecoderOutput(BaseModelOutput): + loc: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None @dataclass -class TimesFMOutputForPrediction(BaseModelOutput): - mean_predictions: torch.Tensor | None = None - full_predictions: torch.Tensor | None = None - loss: torch.Tensor | float | None = None +class TimesFmOutputForPrediction(BaseModelOutput): + mean_predictions: Optional[torch.Tensor] = None + full_predictions: Optional[torch.Tensor] = None + loss: Optional[Union[torch.Tensor, float]] = None -class TimesFMTransformerMLP(nn.Module): +class TimesFmTransformerMLP(nn.Module): """Pax transformer MLP in pytorch.""" def __init__( @@ -64,7 +64,7 @@ def forward(self, x, paddings=None): return outputs + x -class TimesFMResidualBlock(nn.Module): +class TimesFmResidualBlock(nn.Module): """TimesFM residual block.""" def __init__(self, input_dims, hidden_dims, output_dims): @@ -91,7 +91,7 @@ def forward(self, x): return output + residual -class TimesFMRMSNorm(torch.nn.Module): +class TimesFmRMSNorm(torch.nn.Module): """Pax rms norm in pytorch.""" def __init__( @@ -117,7 +117,7 @@ def forward(self, x): return output.type_as(x) -class TimesFMPositionalEmbedding(nn.Module): +class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. Attributes: @@ -172,10 +172,10 @@ def forward(self, seq_length=None, position=None): return signal -class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM. One key diffrence is that there is _per_dim_scaling of the query.""" +class TimesFmAttention(nn.Module): + """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__() self.attention_dropout = config.attention_dropout @@ -205,11 +205,11 @@ def _scale_query(self, query: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: Optional[torch.Tensor] = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -268,17 +268,17 @@ def forward( return output, scores -class TimesFMSdpaAttention(TimesFMAttention): +class TimesFmSdpaAttention(TimesFmAttention): """TimesFM attention implementation using torch.nn.functional.scaled_dot_product_attention.""" def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: Optional[torch.Tensor] = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if output_attentions: return super().forward( hidden_states=hidden_states, @@ -330,7 +330,7 @@ def forward( value = value.contiguous() # Run scaled dot-product attention - # Note: attention_mask should already be in the correct format from TimesFMStackedDecoder + # Note: attention_mask should already be in the correct format from TimesFmStackedDecoder attn_output = F.scaled_dot_product_attention( query, key, @@ -349,15 +349,15 @@ def forward( TIMESFM_ATTENTION_CLASSES = { - "eager": TimesFMAttention, - "sdpa": TimesFMSdpaAttention, + "eager": TimesFmAttention, + "sdpa": TimesFmSdpaAttention, } -class TimesFMDecoderLayer(nn.Module): +class TimesFmDecoderLayer(nn.Module): """Transformer layer.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__() if config._attn_implementation not in TIMESFM_ATTENTION_CLASSES: @@ -365,16 +365,16 @@ def __init__(self, config: TimesFMConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.mlp = TimesFMTransformerMLP(config.model_dim, config.intermediate_size) - self.input_layernorm = TimesFMRMSNorm(config.model_dim, eps=config.rms_norm_eps) + self.mlp = TimesFmTransformerMLP(config.model_dim, config.intermediate_size) + self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, ) -> torch.Tensor: # Self Attention @@ -395,20 +395,20 @@ def forward( return scores, hidden_states -class TimesFMStackedDecoder(nn.Module): +class TimesFmStackedDecoder(nn.Module): """Stacked transformer layer.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__() - self.layers = nn.ModuleList([TimesFMDecoderLayer(config) for _ in range(config.num_layers)]) + self.layers = nn.ModuleList([TimesFmDecoderLayer(config) for _ in range(config.num_layers)]) def forward( self, hidden_states: torch.Tensor, paddings: torch.Tensor, - kv_write_indices: torch.Tensor | None = None, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, + kv_write_indices: Optional[torch.Tensor] = None, + kv_caches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, output_hidden_states: bool = False, ) -> BaseModelOutput: @@ -613,10 +613,10 @@ def expand_t(key_mask): return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum -class TimesFMPreTrainedModel(PreTrainedModel): +class TimesFmPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" - config_class = TimesFMConfig + config_class = TimesFmConfig base_model_prefix = "timesfm" main_input_name = "inputs" _supports_sdpa = True @@ -634,10 +634,10 @@ def _init_weights(self, module): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) - elif isinstance(module, TimesFMRMSNorm): + elif isinstance(module, TimesFmRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, TimesFMTransformerMLP): + elif isinstance(module, TimesFmTransformerMLP): # Initialize gate projection module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.gate_proj.bias is not None: @@ -652,7 +652,7 @@ def _init_weights(self, module): nn.init.ones_(module.layer_norm.weight) nn.init.zeros_(module.layer_norm.bias) - elif isinstance(module, TimesFMAttention): + elif isinstance(module, TimesFmAttention): # Initialize qkv projection module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.qkv_proj.bias is not None: @@ -666,7 +666,7 @@ def _init_weights(self, module): # Initialize scaling parameter nn.init.ones_(module.scaling) - elif isinstance(module, TimesFMResidualBlock): + elif isinstance(module, TimesFmResidualBlock): # Initialize hidden layer module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.hidden_layer[0].bias is not None: @@ -682,26 +682,26 @@ def _init_weights(self, module): if module.residual_layer.bias is not None: nn.init.zeros_(module.residual_layer.bias) - elif isinstance(module, TimesFMPositionalEmbedding): + elif isinstance(module, TimesFmPositionalEmbedding): pass -class TimesFMDecoder(TimesFMPreTrainedModel): +class TimesFmDecoder(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__(config) self.config = config - self.input_ff_layer = TimesFMResidualBlock( + self.input_ff_layer = TimesFmResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, hidden_dims=config.intermediate_size, ) self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) - self.stacked_transformer = TimesFMStackedDecoder(config=config) + self.stacked_transformer = TimesFmStackedDecoder(config=config) if self.config.use_positional_embedding: - self.position_emb = TimesFMPositionalEmbedding( + self.position_emb = TimesFmPositionalEmbedding( embedding_dims=self.config.model_dim, ) @@ -772,7 +772,7 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, - ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]: + ) -> Union[TimesFmDecoderOutput, tuple[torch.Tensor, ...]]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -792,7 +792,7 @@ def forward( all_hidden_states = None if return_dict: - return TimesFMDecoderOutput( + return TimesFmDecoderOutput( last_hidden_state=transformer_output.last_hidden_state, hidden_states=all_hidden_states, attentions=transformer_output.attentions if output_attentions else None, @@ -809,20 +809,20 @@ def forward( ) -class TimesFMModelForPrediction(TimesFMPreTrainedModel): +class TimesFmModelForPrediction(TimesFmPreTrainedModel): """TimesFM model for quantile and mean prediction.""" - def __init__(self, config: TimesFMConfig): + def __init__(self, config: TimesFmConfig): super().__init__(config) self.config = config self.context_len = config.context_len self.horizon_len = config.horizon_len - self.decoder = TimesFMDecoder(config) + self.decoder = TimesFmDecoder(config) # quantile and mean output - self.horizon_ff_layer = TimesFMResidualBlock( + self.horizon_ff_layer = TimesFmResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.intermediate_size, @@ -899,8 +899,8 @@ def decode( paddings: torch.Tensor, freq: torch.LongTensor, horizon_len: int, - output_patch_len: int | None = None, - max_len: int | None = None, + output_patch_len: Optional[int] = None, + max_len: Optional[int] = None, return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, @@ -1005,16 +1005,16 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to def forward( self, inputs: Sequence[torch.Tensor], - freq: Sequence[torch.Tensor | int] | None = None, - window_size: int | None = None, - future_target: torch.Tensor | None = None, - forecast_context_len: int | None = None, + freq: Optional[Sequence[Union[torch.Tensor,int]]] = None, + window_size: Optional[int] = None, + future_target: Optional[torch.Tensor] = None, + forecast_context_len: Optional[int] = None, return_forecast_on_context: bool = False, truncate_negative: bool = False, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]: + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[TimesFmOutputForPrediction, tuple[torch.Tensor, ...]]: """Forecasts on a list of time series. Args: @@ -1032,10 +1032,10 @@ def forward( have non-negative values. output_attentions: Whether to return the attentions. output_hidden_states: Whether to return the hidden states. - return_dict: Whether to return a TimesFMOutputForPrediction object. + return_dict: Whether to return a TimesFmOutputForPrediction object. Returns: - A TimesFMOutputForPrediction object containing: + A TimesFmOutputForPrediction object containing: - the mean forecast of size (# inputs, # forecast horizon), - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). @@ -1109,7 +1109,7 @@ def forward( loss = mse_loss + quantile_loss if return_dict: - return TimesFMOutputForPrediction( + return TimesFmOutputForPrediction( last_hidden_state=last_hidden_state, attentions=all_attentions if output_attentions else None, hidden_states=all_hidden_states if output_hidden_states else None, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 78e387d4f6cc..3e9a26d2b138 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9487,21 +9487,21 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMDecoder(metaclass=DummyObject): +class TimesFmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMModelForPrediction(metaclass=DummyObject): +class TimesFmModelForPrediction(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFMPreTrainedModel(metaclass=DummyObject): +class TimesFmPreTrainedModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 46f758254bbd..c38b1ab07dd0 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google TimesFM Authors and HuggingFace Inc. team. +# Copyright 2024 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import torch from huggingface_hub import hf_hub_download -from transformers import TimesFMConfig, is_torch_available +from transformers import TimesFmConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_fx_available @@ -33,12 +33,12 @@ pass if is_torch_available(): - from transformers import TimesFMModelForPrediction + from transformers import TimesFmModelForPrediction TOLERANCE = 1e-4 -class TimesFMModelTester: +class TimesFmModelTester: def __init__( self, parent, @@ -84,10 +84,10 @@ def __init__( self.hidden_size = model_dim def get_large_model_config(self): - return TimesFMConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") + return TimesFmConfig.from_pretrained("google/timesfm-1.0-200m-pytorch") def get_config(self): - return TimesFMConfig( + return TimesFmConfig( patch_len=self.patch_len, context_len=self.context_len, horizon_len=self.horizon_len, @@ -129,9 +129,9 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () - all_generative_model_classes = (TimesFMModelForPrediction,) if is_torch_available() else () +class TimesFmModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else () + all_generative_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else () all_parallelizable_model_classes = () fx_compatible = False test_pruning = False @@ -141,12 +141,12 @@ class TimesFMModelTest(ModelTesterMixin, unittest.TestCase): test_inputs_embeds = False def setUp(self): - self.model_tester = TimesFMModelTester(self) - self.config_tester = ConfigTester(self, config_class=TimesFMConfig) + self.model_tester = TimesFmModelTester(self) + self.config_tester = ConfigTester(self, config_class=TimesFmConfig) def test_create_and_run_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = TimesFMModelForPrediction(config) + model = TimesFmModelForPrediction(config) model.to(torch_device) model.eval() results = model(**inputs_dict) @@ -162,15 +162,15 @@ def test_headmasking(self): # the main input name is `inputs` def test_model_main_input_name(self): - model_signature = inspect.signature(getattr(TimesFMModelForPrediction, "forward")) + model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward")) # The main input is the name of the argument after `self` observed_main_input_name = list(model_signature.parameters.keys())[1] - self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name) + self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name) @require_torch @slow -class TimesFMModelIntegrationTests(unittest.TestCase): +class TimesFmModelIntegrationTests(unittest.TestCase): @classmethod def load_batch(cls, filename="train-batch.pt"): file = hf_hub_download( @@ -180,7 +180,7 @@ def load_batch(cls, filename="train-batch.pt"): return batch def test_inference_no_head(self): - model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device) + model = TimesFmModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device) batch = self.load_batch() with torch.no_grad(): inputs = batch["past_values"] From 53b290aef78a8b39f767fdbb91687506fb9f996c Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Mon, 13 Jan 2025 19:12:57 -0800 Subject: [PATCH 146/242] consolidate decode func --- .../models/timesfm/modeling_timesfm.py | 176 ++++++------------ 1 file changed, 59 insertions(+), 117 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c4b3ddac9d7a..c8e3c28589e6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -893,107 +893,6 @@ def _reverse_transform(self, outputs: torch.Tensor, stats: tuple[torch.Tensor, t mu, sigma = stats return outputs * sigma[:, None, None, None] + mu[:, None, None, None] - def decode( - self, - input_ts: torch.Tensor, - paddings: torch.Tensor, - freq: torch.LongTensor, - horizon_len: int, - output_patch_len: Optional[int] = None, - max_len: Optional[int] = None, - return_forecast_on_context: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - ) -> tuple[torch.Tensor, ...]: - """Auto-regressive decoding without caching. - - Args: - input_ts: input time-series and paddings. Time-series shape B x C. - paddings: padding shape B x (C + H) where H is the prediction length. - freq: frequency shape B x 1 - horizon_len: prediction length. - output_patch_len: output length to be fetched from one step of - auto-regressive decoding. - max_len: maximum training context length. If None, then we use the length - of the initial context as max length. - return_forecast_on_context: whether to return the model forecast on the - context except the first input patch. - - Returns: - Tuple of two forecasting results: - - Point (mean) output predictions as a tensor with shape B x H'. - - Full predictions (mean and quantiles) as a tensor with shape - B x H' x (1 + # quantiles). - In particular, if return_forecast_on_context is True, H' is H plus - the forecastable context length, i.e. context_len - (first) patch_len. - - Raises: - ValueError: If the paddings do not match the input + horizon_len. - """ - final_out = input_ts - context_len = final_out.shape[1] - full_outputs = [] - - if paddings.shape[1] != final_out.shape[1] + horizon_len: - raise ValueError( - "Length of paddings must match length of input + horizon_len:" - f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}" - ) - if output_patch_len is None: - output_patch_len = self.config.horizon_len - if max_len is None: - max_len = context_len - - num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len - for step_index in range(num_decode_patches): - current_padding = paddings[:, 0 : final_out.shape[1]] - input_ts = final_out[:, -max_len:] - input_padding = current_padding[:, -max_len:] - decoder_output = self.decoder( - input_ts, - input_padding, - freq, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - fprop_outputs = self._postprocess_output( - decoder_output.last_hidden_state, - (decoder_output.loc, decoder_output.scale), - ) - - if return_forecast_on_context and step_index == 0: - # For the first decodings step, collect the model forecast on the - # context except the unavailable first input batch forecast. - new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] - # We have to use reshape and not view for non-contiguous memory - new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3)) - - full_outputs.append(new_full_ts) - - # (full batch, last patch, output_patch_len, index of mean forecast = 0) - new_ts = fprop_outputs[:, -1, :output_patch_len, 0] - new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] - # (full batch, last patch, output_patch_len, all output indices) - full_outputs.append(new_full_ts) - final_out = torch.concatenate([final_out, new_ts], axis=-1) - - if return_forecast_on_context: - # `full_outputs` indexing starts at after the first input patch. - full_outputs = torch.concatenate(full_outputs, axis=1)[ - :, : (context_len - self.config.patch_len + horizon_len), : - ] - else: - # `full_outputs` indexing starts at the forecast horizon. - full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :] - - return ( - full_outputs[:, :, 0], - full_outputs, - decoder_output.last_hidden_state, - decoder_output.attentions, - decoder_output.hidden_states, - ) - def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: losses = [] for i, q in enumerate(self.config.quantiles): @@ -1083,18 +982,61 @@ def forward( input_ts = input_ts.to(device) input_padding = input_padding.to(device) inp_freq = inp_freq.to(device) + + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] - mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode( - input_ts=input_ts, - paddings=input_padding, - freq=inp_freq, - horizon_len=self.horizon_len, - return_forecast_on_context=return_forecast_on_context, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - max_len=fcontext_len, - ) + if input_padding.shape[1] != final_out.shape[1] + self.horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}" + ) + output_patch_len = self.config.horizon_len + + num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = input_padding[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -fcontext_len:] + input_padding = current_padding[:, -fcontext_len:] + decoder_output = self.decoder( + input_ts, + input_padding, + inp_freq, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + fprop_outputs = self._postprocess_output( + decoder_output.last_hidden_state, + (decoder_output.loc, decoder_output.scale), + ) + + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_len, :] + # We have to use reshape and not view for non-contiguous memory + new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3)) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concatenate([final_out, new_ts], axis=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concatenate(full_outputs, axis=1)[ + :, : (context_len - self.config.patch_len + self.horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:self.horizon_len, :] + mean_outputs = full_outputs[:, :, 0] if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] @@ -1110,18 +1052,18 @@ def forward( if return_dict: return TimesFmOutputForPrediction( - last_hidden_state=last_hidden_state, - attentions=all_attentions if output_attentions else None, - hidden_states=all_hidden_states if output_hidden_states else None, + last_hidden_state=decoder_output.last_hidden_state, + attentions=decoder_output.all_attentions if output_attentions else None, + hidden_states=decoder_output.all_hidden_states if output_hidden_states else None, mean_predictions=mean_outputs, full_predictions=full_outputs, loss=loss, ) else: - return_tuple = [last_hidden_state] + return_tuple = [decoder_output.last_hidden_state] if output_hidden_states: - return_tuple.append(all_hidden_states) + return_tuple.append(decoder_output.all_hidden_states) if output_attentions: - return_tuple.append(all_attentions) + return_tuple.append(decoder_output.all_attentions) return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) From c65e4b46dbb40c2af9bb596c7979c58f01d3a1a9 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Wed, 5 Feb 2025 18:38:12 -0800 Subject: [PATCH 147/242] more readable code for weight conversion --- docs/source/en/model_doc/timesfm.md | 3 - src/transformers/models/auto/modeling_auto.py | 2 +- .../timesfm/convert_timesfm_orignal_to_hf.py | 151 +++++++++--------- 3 files changed, 73 insertions(+), 83 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 3edf1aedbb0a..144b29769faf 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -43,6 +43,3 @@ The original code can be found [here](https://github.com/google-research/timesfm [[autodoc]] TimesFmModelForPrediction - forward - - - diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0c93ccc04ae9..97b6944a1dec 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -271,7 +271,7 @@ ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), - ("timesfm", "TimesFmModelForPrediction"), + ("timesfm", "TimesFmModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index 136efdd7d3e9..d3db52aa8dbf 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -1,6 +1,7 @@ import argparse import os import shutil +import re import numpy as np import timesfm @@ -19,6 +20,19 @@ """ +def get_nested_attr(obj, key): + """Recursively retrieves an attribute from an object, handling list/tuple indexing if present.""" + parts = key.split('.') + for part in parts: + match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]` + if match: + attr_name, index = match.groups() + obj = getattr(obj, attr_name)[int(index)] # Access list/tuple element + else: + obj = getattr(obj, part) # Regular attribute access + return obj + + def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") @@ -54,85 +68,64 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google # copy the weights from the original model to the new model making original_model = tfm._model - - # Map decoder input_ff_layer - timesfm_model.decoder.input_ff_layer.hidden_layer[0].weight.data = original_model.input_ff_layer.hidden_layer[ - 0 - ].weight.data - timesfm_model.decoder.input_ff_layer.hidden_layer[0].bias.data = original_model.input_ff_layer.hidden_layer[ - 0 - ].bias.data - timesfm_model.decoder.input_ff_layer.output_layer.weight.data = ( - original_model.input_ff_layer.output_layer.weight.data - ) - timesfm_model.decoder.input_ff_layer.output_layer.bias.data = original_model.input_ff_layer.output_layer.bias.data - timesfm_model.decoder.input_ff_layer.residual_layer.weight.data = ( - original_model.input_ff_layer.residual_layer.weight.data - ) - timesfm_model.decoder.input_ff_layer.residual_layer.bias.data = ( - original_model.input_ff_layer.residual_layer.bias.data - ) - - # Map freq embedding - timesfm_model.decoder.freq_emb.weight.data = original_model.freq_emb.weight.data - - # Map horizon_ff_layer - timesfm_model.horizon_ff_layer.hidden_layer[0].weight.data = original_model.horizon_ff_layer.hidden_layer[ - 0 - ].weight.data - timesfm_model.horizon_ff_layer.hidden_layer[0].bias.data = original_model.horizon_ff_layer.hidden_layer[ - 0 - ].bias.data - timesfm_model.horizon_ff_layer.output_layer.weight.data = original_model.horizon_ff_layer.output_layer.weight.data - timesfm_model.horizon_ff_layer.output_layer.bias.data = original_model.horizon_ff_layer.output_layer.bias.data - timesfm_model.horizon_ff_layer.residual_layer.weight.data = ( - original_model.horizon_ff_layer.residual_layer.weight.data - ) - timesfm_model.horizon_ff_layer.residual_layer.bias.data = original_model.horizon_ff_layer.residual_layer.bias.data - - # Map transformer layers - for i in range(len(timesfm_model.decoder.stacked_transformer.layers)): - # Map attention layers - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.qkv_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.qkv_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.qkv_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.o_proj.weight.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.o_proj.bias.data = original_model.stacked_transformer.layers[i].self_attn.o_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].self_attn.scaling.data = original_model.stacked_transformer.layers[i].self_attn.scaling.data - - # Map MLP layers - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.gate_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.gate_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.gate_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.down_proj.weight.data = original_model.stacked_transformer.layers[i].mlp.down_proj.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.down_proj.bias.data = original_model.stacked_transformer.layers[i].mlp.down_proj.bias.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.layer_norm.weight.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.weight.data - timesfm_model.decoder.stacked_transformer.layers[ - i - ].mlp.layer_norm.bias.data = original_model.stacked_transformer.layers[i].mlp.layer_norm.bias.data - - # Map layer norms - timesfm_model.decoder.stacked_transformer.layers[ - i - ].input_layernorm.weight.data = original_model.stacked_transformer.layers[i].input_layernorm.weight.data + + # mapping of the layers from the original model to the transformer model + MODEL_LAYER_MAPPING = { + "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", + "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.hidden_layer[0].bias", + "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight", + "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias", + "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight", + "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias", + + "freq_emb.weight": "decoder.freq_emb.weight", + + "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.hidden_layer[0].weight", + "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.hidden_layer[0].bias", + "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", + "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", + "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", + "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", + } + + + TRANSFORMER_LAYER_MAPPING = { + "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", + "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", + "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.weight", + "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.bias", + "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.stacked_transformer.layers[{i}].self_attn.scaling", + + "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.weight", + "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.bias", + "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.weight", + "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.bias", + "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.weight", + "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.bias", + + "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.stacked_transformer.layers[{i}].input_layernorm.weight", + } + + for old_key, new_key in MODEL_LAYER_MAPPING.items(): + try: + old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model + new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") + + num_layers = len(timesfm_model.decoder.stacked_transformer.layers) + for i in range(num_layers): + for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items(): + old_key = old_template.format(i=i) + new_key = new_template.format(i=i) + + try: + old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model + new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path) From b428972d4f7a4fe29c4a884d0bc93bbdb804cc15 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 19:51:35 +0100 Subject: [PATCH 148/242] fix-copies --- docs/source/en/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3971e67557b1..f3b009dc02c0 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -340,7 +340,7 @@ Flax), PyTorch, and/or TensorFlow. | [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ | | [TextNet](model_doc/textnet) | ✅ | ❌ | ❌ | | [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ | -| [TimesFM](model_doc/timesfm) | ✅ | ❌ | ❌ | +| [TimesFm](model_doc/timesfm) | ✅ | ❌ | ❌ | | [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ | | [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ | | [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ | From ea05e276c10624f3e97f0fa01bb79eb63627cc4d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:03:57 +0100 Subject: [PATCH 149/242] simpler init --- src/transformers/models/timesfm/__init__.py | 39 ++++--------------- .../models/timesfm/configuration_timesfm.py | 2 +- .../timesfm/convert_timesfm_orignal_to_hf.py | 6 +-- .../models/timesfm/modeling_timesfm.py | 6 +-- 4 files changed, 14 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py index 0441d2ce1eda..12f1541b9c54 100644 --- a/src/transformers/models/timesfm/__init__.py +++ b/src/transformers/models/timesfm/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,42 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import TYPE_CHECKING -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, -) - - -_import_structure = {"configuration_timesfm": ["TimesFmConfig", "TimesFmOnnxConfig"]} +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_timesfm"] = [ - "TimesFmModelForPrediction", - "TimesFmDecoder", - "TimesFmPreTrainedModel", - ] if TYPE_CHECKING: - from .configuration_timesfm import TimesFmConfig, TimesFmOnnxConfig - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_timesfm import TimesFmDecoder, TimesFmModelForPrediction, TimesFmPreTrainedModel - + from .configuration_timesfm import * + from .modeling_timesfm import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 62e4920ca14a..086cfad5c33d 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. +# Copyright 2025 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index d3db52aa8dbf..c3219d9a6097 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -1,7 +1,7 @@ import argparse import os -import shutil import re +import shutil import numpy as np import timesfm @@ -68,7 +68,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google # copy the weights from the original model to the new model making original_model = tfm._model - + # mapping of the layers from the original model to the transformer model MODEL_LAYER_MAPPING = { "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", @@ -88,7 +88,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", } - + TRANSFORMER_LAYER_MAPPING = { "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c8e3c28589e6..dc6a44b5341b 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Google LLC and HuggingFace Inc. team. +# Copyright 2025 Google LLC and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import logging import math from dataclasses import dataclass -from typing import List, Sequence, Tuple, Optional, Union +from typing import List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -982,7 +982,7 @@ def forward( input_ts = input_ts.to(device) input_padding = input_padding.to(device) inp_freq = inp_freq.to(device) - + final_out = input_ts context_len = final_out.shape[1] full_outputs = [] From 038859d6162e04a6f035262af57e5033dee47fc5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:07:25 +0100 Subject: [PATCH 150/242] renaem TimesFmMLP --- src/transformers/models/timesfm/modeling_timesfm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index dc6a44b5341b..79a852ab8c01 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -41,8 +41,8 @@ class TimesFmOutputForPrediction(BaseModelOutput): loss: Optional[Union[torch.Tensor, float]] = None -class TimesFmTransformerMLP(nn.Module): - """Pax transformer MLP in pytorch.""" +class TimesFmMLP(nn.Module): + """Pax MLP in pytorch.""" def __init__( self, @@ -365,7 +365,7 @@ def __init__(self, config: TimesFmConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.mlp = TimesFmTransformerMLP(config.model_dim, config.intermediate_size) + self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size) self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( @@ -637,7 +637,7 @@ def _init_weights(self, module): elif isinstance(module, TimesFmRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, TimesFmTransformerMLP): + elif isinstance(module, TimesFmMLP): # Initialize gate projection module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) if module.gate_proj.bias is not None: From ef59621e808f06089effb735fc0d092431f7ea5d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:34:12 +0100 Subject: [PATCH 151/242] use T5LayerNorm --- .../models/timesfm/configuration_timesfm.py | 3 + .../timesfm/convert_timesfm_orignal_to_hf.py | 57 +++++++++---------- .../models/timesfm/modeling_timesfm.py | 46 +++++++-------- 3 files changed, 52 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 086cfad5c33d..81ce310fe098 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -152,3 +152,6 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: @property def default_onnx_opset(self) -> int: return 13 + + +__all__ = ["TimesFmConfig"] diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py index c3219d9a6097..f1450fda6910 100644 --- a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py +++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -22,7 +22,7 @@ def get_nested_attr(obj, key): """Recursively retrieves an attribute from an object, handling list/tuple indexing if present.""" - parts = key.split('.') + parts = key.split(".") for part in parts: match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]` if match: @@ -71,39 +71,34 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google # mapping of the layers from the original model to the transformer model MODEL_LAYER_MAPPING = { - "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", - "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.hidden_layer[0].bias", - "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight", - "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias", - "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight", - "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias", - - "freq_emb.weight": "decoder.freq_emb.weight", - - "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.hidden_layer[0].weight", - "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.hidden_layer[0].bias", - "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", - "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", - "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", - "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", + "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.hidden_layer[0].weight", + "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.hidden_layer[0].bias", + "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight", + "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias", + "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight", + "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias", + "freq_emb.weight": "decoder.freq_emb.weight", + "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.hidden_layer[0].weight", + "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.hidden_layer[0].bias", + "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight", + "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias", + "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight", + "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias", } - TRANSFORMER_LAYER_MAPPING = { - "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", - "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", - "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.weight", - "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.bias", - "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.stacked_transformer.layers[{i}].self_attn.scaling", - - "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.weight", - "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.bias", - "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.weight", - "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.bias", - "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.weight", - "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.bias", - - "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.stacked_transformer.layers[{i}].input_layernorm.weight", + "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.weight", + "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.qkv_proj.bias", + "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.weight", + "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.stacked_transformer.layers[{i}].self_attn.o_proj.bias", + "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.stacked_transformer.layers[{i}].self_attn.scaling", + "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.weight", + "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.gate_proj.bias", + "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.weight", + "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.stacked_transformer.layers[{i}].mlp.down_proj.bias", + "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.weight", + "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.stacked_transformer.layers[{i}].mlp.layer_norm.bias", + "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.stacked_transformer.layers[{i}].input_layernorm.weight", } for old_key, new_key in MODEL_LAYER_MAPPING.items(): diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 79a852ab8c01..49b35cd77890 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -91,30 +91,24 @@ def forward(self, x): return output + residual -class TimesFmRMSNorm(torch.nn.Module): - """Pax rms norm in pytorch.""" - - def __init__( - self, - dim: int, - eps: float = 1e-6, - add_unit_offset: bool = False, - ): +class TimesFmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + TimesFmRMSNorm is equivalent to T5LayerNorm + """ super().__init__() - self.eps = eps - self.add_unit_offset = add_unit_offset - self.weight = nn.Parameter(torch.zeros(dim)) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) - def forward(self, x): - output = self._norm(x.float()) - if self.add_unit_offset: - output = output * (1 + self.weight.float()) - else: - output = output * self.weight.float() - return output.type_as(x) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class TimesFmPositionalEmbedding(nn.Module): @@ -904,7 +898,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to def forward( self, inputs: Sequence[torch.Tensor], - freq: Optional[Sequence[Union[torch.Tensor,int]]] = None, + freq: Optional[Sequence[Union[torch.Tensor, int]]] = None, window_size: Optional[int] = None, future_target: Optional[torch.Tensor] = None, forecast_context_len: Optional[int] = None, @@ -1034,7 +1028,7 @@ def forward( ] else: # `full_outputs` indexing starts at the forecast horizon. - full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0:self.horizon_len, :] + full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :] mean_outputs = full_outputs[:, :, 0] if window_size is not None: @@ -1067,3 +1061,9 @@ def forward( return_tuple.append(decoder_output.all_attentions) return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) + + +__all__ = [ + "TimesFmModelForPrediction", + "TimesFmPreTrainedModel", +] From d8c2e0d74fd721b0808eb6d317c100792b315f39 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:43:45 +0100 Subject: [PATCH 152/242] fix tests --- docs/source/en/perf_infer_gpu_one.md | 2 +- src/transformers/models/timesfm/modeling_timesfm.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 508d7fe3f579..f1ff8a02baeb 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -314,7 +314,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) -* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmDecoder) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 49b35cd77890..b03042faf0d9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -1047,8 +1047,8 @@ def forward( if return_dict: return TimesFmOutputForPrediction( last_hidden_state=decoder_output.last_hidden_state, - attentions=decoder_output.all_attentions if output_attentions else None, - hidden_states=decoder_output.all_hidden_states if output_hidden_states else None, + attentions=decoder_output.attentions if output_attentions else None, + hidden_states=decoder_output.hidden_states if output_hidden_states else None, mean_predictions=mean_outputs, full_predictions=full_outputs, loss=loss, @@ -1056,9 +1056,9 @@ def forward( else: return_tuple = [decoder_output.last_hidden_state] if output_hidden_states: - return_tuple.append(decoder_output.all_hidden_states) + return_tuple.append(decoder_output.hidden_states) if output_attentions: - return_tuple.append(decoder_output.all_attentions) + return_tuple.append(decoder_output.attentions) return_tuple += [mean_outputs, full_outputs, loss] return tuple(return_tuple) @@ -1066,4 +1066,5 @@ def forward( __all__ = [ "TimesFmModelForPrediction", "TimesFmPreTrainedModel", + "TimesFmDecoder", ] From a75b8e7a2dc19ece994457d46295a2ca88debd62 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 20:57:31 +0100 Subject: [PATCH 153/242] use initializer_range --- .../models/timesfm/configuration_timesfm.py | 9 ++++----- .../models/timesfm/modeling_timesfm.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 81ce310fe098..ea110df5a573 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -66,9 +66,8 @@ class TimesFmConfig(PretrainedConfig): The dropout probability for the attention scores. use_positional_embedding (`bool`, *optional*, defaults to `True`): Whether to add positional embeddings. - initializer_factor (`float`, *optional*, defaults to 1.0): - A factor for initializing all weight matrices (should be kept to 1, used internally for initialization - testing). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. """ model_type = "timesfm" @@ -97,7 +96,7 @@ def __init__( pad_val: float = 1123581321.0, attention_dropout: float = 0.0, use_positional_embedding: bool = True, - initializer_factor: float = 1.0, + initializer_range: float = 0.02, **kwargs, ): self.patch_len = patch_len @@ -115,7 +114,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.attention_dropout = attention_dropout self.use_positional_embedding = use_positional_embedding - self.initializer_factor = initializer_factor + self.initializer_range = initializer_range super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b03042faf0d9..c95ad08e71f9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -617,10 +617,10 @@ class TimesFmPreTrainedModel(PreTrainedModel): def _init_weights(self, module): if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_range) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.bias is not None: nn.init.zeros_(module.bias) @@ -633,12 +633,12 @@ def _init_weights(self, module): elif isinstance(module, TimesFmMLP): # Initialize gate projection - module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.gate_proj.bias is not None: nn.init.zeros_(module.gate_proj.bias) # Initialize down projection - module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.down_proj.bias is not None: nn.init.zeros_(module.down_proj.bias) @@ -648,12 +648,12 @@ def _init_weights(self, module): elif isinstance(module, TimesFmAttention): # Initialize qkv projection - module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.qkv_proj.bias is not None: nn.init.zeros_(module.qkv_proj.bias) # Initialize output projection - module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.o_proj.bias is not None: nn.init.zeros_(module.o_proj.bias) @@ -662,17 +662,17 @@ def _init_weights(self, module): elif isinstance(module, TimesFmResidualBlock): # Initialize hidden layer - module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range) if module.hidden_layer[0].bias is not None: nn.init.zeros_(module.hidden_layer[0].bias) # Initialize output layer - module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.output_layer.bias is not None: nn.init.zeros_(module.output_layer.bias) # Initialize residual layer - module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor) + module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.residual_layer.bias is not None: nn.init.zeros_(module.residual_layer.bias) From 5352cda52f5bfbb733d1d12eddf4dc1037e984d5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:14:30 +0100 Subject: [PATCH 154/242] TimesFmModel instead of TimesFmDecoder --- docs/source/en/model_doc/timesfm.md | 4 ++-- src/transformers/__init__.py | 4 ++-- .../models/timesfm/configuration_timesfm.py | 2 +- src/transformers/models/timesfm/modeling_timesfm.py | 12 ++++++------ src/transformers/utils/dummy_pt_objects.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 144b29769faf..88366594e803 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -34,9 +34,9 @@ The original code can be found [here](https://github.com/google-research/timesfm [[autodoc]] TimesFmConfig -## TimesFmDecoder +## TimesFmModel -[[autodoc]] TimesFmDecoder +[[autodoc]] TimesFmModel - forward ## TimesFmModelForPrediction diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 55146c35e7fe..06b3538113d2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3742,7 +3742,7 @@ ) _import_structure["models.timesfm"].extend( [ - "TimesFmDecoder", + "TimesFmModel", "TimesFmModelForPrediction", "TimesFmPreTrainedModel", ] @@ -8455,7 +8455,7 @@ TimeSeriesTransformerPreTrainedModel, ) from .models.timesfm import ( - TimesFmDecoder, + TimesFmModel, TimesFmModelForPrediction, TimesFmPreTrainedModel, ) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index ea110df5a573..6f17b0c7bfcc 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -26,7 +26,7 @@ class TimesFmConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmDecoder`]. It is used to + This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmModel`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c95ad08e71f9..6d76a17ea646 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -29,7 +29,7 @@ @dataclass -class TimesFmDecoderOutput(BaseModelOutput): +class TimesFmOutput(BaseModelOutput): loc: Optional[torch.Tensor] = None scale: Optional[torch.Tensor] = None @@ -680,7 +680,7 @@ def _init_weights(self, module): pass -class TimesFmDecoder(TimesFmPreTrainedModel): +class TimesFmModel(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" def __init__(self, config: TimesFmConfig): @@ -766,7 +766,7 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, - ) -> Union[TimesFmDecoderOutput, tuple[torch.Tensor, ...]]: + ) -> Union[TimesFmOutput, tuple[torch.Tensor, ...]]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -786,7 +786,7 @@ def forward( all_hidden_states = None if return_dict: - return TimesFmDecoderOutput( + return TimesFmOutput( last_hidden_state=transformer_output.last_hidden_state, hidden_states=all_hidden_states, attentions=transformer_output.attentions if output_attentions else None, @@ -813,7 +813,7 @@ def __init__(self, config: TimesFmConfig): self.context_len = config.context_len self.horizon_len = config.horizon_len - self.decoder = TimesFmDecoder(config) + self.decoder = TimesFmModel(config) # quantile and mean output self.horizon_ff_layer = TimesFmResidualBlock( @@ -1066,5 +1066,5 @@ def forward( __all__ = [ "TimesFmModelForPrediction", "TimesFmPreTrainedModel", - "TimesFmDecoder", + "TimesFmModel", ] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 3e9a26d2b138..e10821e84c05 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9487,7 +9487,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class TimesFmDecoder(metaclass=DummyObject): +class TimesFmModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From f460370e532fd0a25bebad397f2109f56d614191 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:24:56 +0100 Subject: [PATCH 155/242] TimesFmPositionalEmbedding takes config for its init --- .../models/timesfm/configuration_timesfm.py | 10 ++++++++ .../models/timesfm/modeling_timesfm.py | 24 ++++--------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index 6f17b0c7bfcc..d8eaabdc9bfa 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -68,6 +68,12 @@ class TimesFmConfig(PretrainedConfig): Whether to add positional embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + min_timescale (`int`, *optional*, defaults to 1): + The start of the geometric positional index. Determines the periodicity of + the added signal. + max_timescale (`int`, *optional*, defaults to 10_000): + The end of the geometric positional index. Determines the frequency of the + added signal. """ model_type = "timesfm" @@ -97,6 +103,8 @@ def __init__( attention_dropout: float = 0.0, use_positional_embedding: bool = True, initializer_range: float = 0.02, + min_timescale: int = 1, + max_timescale: int = 10_000, **kwargs, ): self.patch_len = patch_len @@ -115,6 +123,8 @@ def __init__( self.attention_dropout = attention_dropout self.use_positional_embedding = use_positional_embedding self.initializer_range = initializer_range + self.min_timescale = min_timescale + self.max_timescale = max_timescale super().__init__( is_encoder_decoder=self.is_encoder_decoder, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 6d76a17ea646..f4070fa889df 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -113,25 +113,13 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. - - Attributes: - embedding_dims: Dimension of the embedding to be generated. - min_timescale: Start of the geometric index. Determines the periodicity of - the added signal. Defaults to 1. - max_timescale: End of the geometric index. Determines the frequency of the - added signal. Defaults to 10_000. """ - def __init__( - self, - embedding_dims: int, - min_timescale: int = 1, - max_timescale: int = 10_000, - ) -> None: + def __init__(self, config: TimesFmConfig) -> None: super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dims = embedding_dims + self.min_timescale = config.min_timescale + self.max_timescale = config.max_timescale + self.embedding_dims = config.model_dim def forward(self, seq_length=None, position=None): """Generates a Tensor of sinusoids with different frequencies. @@ -695,9 +683,7 @@ def __init__(self, config: TimesFmConfig): self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.model_dim) self.stacked_transformer = TimesFmStackedDecoder(config=config) if self.config.use_positional_embedding: - self.position_emb = TimesFmPositionalEmbedding( - embedding_dims=self.config.model_dim, - ) + self.position_emb = TimesFmPositionalEmbedding(config=config) # Initialize weights and apply final processing self.post_init() From 913f3608be439cdbf534f27fed9178a7298f76ab Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:35:42 +0100 Subject: [PATCH 156/242] 2.0-500m-pytorch default configs --- .../models/timesfm/configuration_timesfm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index d8eaabdc9bfa..c4577c9e599a 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -29,7 +29,7 @@ class TimesFmConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TFTimesFmModel`]. It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM - [google/timesfm-1.0-200m](https://huggingface.co/google/timesfm-1.0-200m) architecture. + [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -43,7 +43,7 @@ class TimesFmConfig(PretrainedConfig): The length of the prediction horizon. freq_size (`int`, *optional*, defaults to 3): The number of frequency embeddings. - num_layers (`int`, *optional*, defaults to 20): + num_layers (`int`, *optional*, defaults to 50): Number of Transformer layers. model_dim (`int`, *optional*, defaults to 1280): Size of the hidden layers in the feed-forward networks. @@ -64,7 +64,7 @@ class TimesFmConfig(PretrainedConfig): The value used to pad the predictions. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout probability for the attention scores. - use_positional_embedding (`bool`, *optional*, defaults to `True`): + use_positional_embedding (`bool`, *optional*, defaults to `False`): Whether to add positional embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. @@ -91,7 +91,7 @@ def __init__( context_len: int = 512, horizon_len: int = 128, freq_size: int = 3, - num_layers: int = 20, + num_layers: int = 50, model_dim: int = 1280, intermediate_size: int = 1280, head_dim: int = 80, @@ -101,7 +101,7 @@ def __init__( quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], pad_val: float = 1123581321.0, attention_dropout: float = 0.0, - use_positional_embedding: bool = True, + use_positional_embedding: bool = False, initializer_range: float = 0.02, min_timescale: int = 1, max_timescale: int = 10_000, From 02e62c6dc91aef342430cf204a6b76cfb99d99c5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 14 Feb 2025 21:42:57 +0100 Subject: [PATCH 157/242] use TimesFmModel --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index f1ff8a02baeb..508d7fe3f579 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -314,7 +314,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) -* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmDecoder) +* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) From 4466315657cd7f7180741cc3f1207e320e81d706 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 16 Feb 2025 19:47:09 +0100 Subject: [PATCH 158/242] fix formatting --- src/transformers/models/timesfm/modeling_timesfm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f4070fa889df..c2c8a13f3b6e 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -112,8 +112,7 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): - """Generates position embedding for a given 1-d sequence. - """ + """Generates position embedding for a given 1-d sequence.""" def __init__(self, config: TimesFmConfig) -> None: super().__init__() From df7bbb0610eb88ccd8140799e955ce02e1e10fcf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 17 Feb 2025 08:59:43 +0100 Subject: [PATCH 159/242] ignore TimesFmModel for testing --- utils/check_repo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index 3b3dddf9cf63..1efee6857048 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -144,6 +144,7 @@ "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests "Emu3VQVAE", # Building part of bigger (tested) model "Emu3TextModel", # Building part of bigger (tested) model + "TimesFmModel", # Building part of bigger (tested) model ] ) From c0a4f487598b6c4c0bf1084620d785e378ae0fe5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 11:57:20 +0100 Subject: [PATCH 160/242] fix docstring --- src/transformers/models/timesfm/configuration_timesfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py index c4577c9e599a..570d39c02221 100644 --- a/src/transformers/models/timesfm/configuration_timesfm.py +++ b/src/transformers/models/timesfm/configuration_timesfm.py @@ -71,7 +71,7 @@ class TimesFmConfig(PretrainedConfig): min_timescale (`int`, *optional*, defaults to 1): The start of the geometric positional index. Determines the periodicity of the added signal. - max_timescale (`int`, *optional*, defaults to 10_000): + max_timescale (`int`, *optional*, defaults to 10000): The end of the geometric positional index. Determines the frequency of the added signal. """ From 71bda445b7e58e02a9b6cda963c113346d9d7a01 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 12:04:51 +0100 Subject: [PATCH 161/242] override generate as its not needed --- .../models/timesfm/modeling_timesfm.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c2c8a13f3b6e..3e9cc01f8178 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -666,6 +666,32 @@ def _init_weights(self, module): elif isinstance(module, TimesFmPositionalEmbedding): pass + def generate(self, *args, **kwargs): + """ + This method is disabled for TimesFM models. TimesFM models are designed for time series forecasting and should be used + with the forward() method instead. For forecasting, use: + + ```python + # For basic forecasting: + outputs = model(input_ts=your_time_series, input_padding=your_padding, freq=your_frequency) + + # For prediction with quantiles: + outputs = model.forward( + inputs=your_time_series_list, + freq=your_frequencies, + window_size=optional_window_size, + future_target=optional_target, + forecast_context_len=optional_context_length + ) + ``` + + See the model's documentation for more details on the forward method parameters. + """ + raise NotImplementedError( + "The generate() method is not implemented for TimesFM models as they are designed for time series " + "forecasting. Please use the forward() method instead. See the docstring of this method for usage examples." + ) + class TimesFmModel(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" From b7e75e9f743f70284e700e8ad8c8f6e8c724837f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 12:28:06 +0100 Subject: [PATCH 162/242] add doc strings --- .../models/timesfm/modeling_timesfm.py | 100 +++++++++++++----- 1 file changed, 75 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 3e9cc01f8178..bf66c4d4f5f7 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -25,9 +25,15 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_timesfm import TimesFmConfig +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TimesFmConfig" + + @dataclass class TimesFmOutput(BaseModelOutput): loc: Optional[torch.Tensor] = None @@ -594,6 +600,27 @@ def expand_t(key_mask): return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum +TIMESFM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TimesFmConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) class TimesFmPreTrainedModel(PreTrainedModel): """handles the loading for all models.""" @@ -693,6 +720,27 @@ def generate(self, *args, **kwargs): ) +TIMESFM_INPUTS_DOCSTRING = r""" + Args: + inputs: list of time series forecast contexts. Each context time series + should be a torch Tensor of potentially different context lengths. + freq: frequency of each context time series in the inputs. 0 for high frequency + (default), 1 for medium, and 2 for low. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) class TimesFmModel(TimesFmPreTrainedModel): """Patched time-series decoder without any specific output layer.""" @@ -769,17 +817,23 @@ def _preprocess_input( return model_input, patched_padding, stats, patched_inputs + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) def forward( self, - input_ts: torch.Tensor, + inputs: torch.Tensor, input_padding: torch.LongTensor, freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ) -> Union[TimesFmOutput, tuple[torch.Tensor, ...]]: + """ + input_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The padding indicator of the time series. + """ + model_input, patched_padding, stats, _ = self._preprocess_input( - input_ts=input_ts, + input_ts=inputs, input_padding=input_padding, ) f_emb = self.freq_emb(freq) # B x 1 x D @@ -906,6 +960,8 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to losses.append(loss.mean()) return torch.stack(losses).mean() + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC) def forward( self, inputs: Sequence[torch.Tensor], @@ -919,31 +975,25 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[TimesFmOutputForPrediction, tuple[torch.Tensor, ...]]: - """Forecasts on a list of time series. - - Args: - inputs: list of time series forecast contexts. Each context time series - should be a torch Tensor of potentially different context lengths. - freq: frequency of each context time series in the inputs. 0 for high frequency - (default), 1 for medium, and 2 for low. - window_size: window size of trend + residual decomposition. If None then - we do not do decomposition. - future_target: optional future target time series to be used for loss computation. - forecast_context_len: optional max context length. - return_forecast_on_context: True to return the forecast on the context - when available, i.e. after the first input patch. - truncate_negative: truncate to only non-negative values if all the contexts - have non-negative values. - output_attentions: Whether to return the attentions. - output_hidden_states: Whether to return the hidden states. - return_dict: Whether to return a TimesFmOutputForPrediction object. + r""" + window_size (`int`, *optional*): + Window size of trend + residual decomposition. If None thenwe do not do decomposition. + future_target (`torch.Tensor`, *optional*): + Optional future target time series to be used for loss computation. + forecast_context_len (`int`, *optional*): + Optional max context length. + return_forecast_on_context (`bool`, *optional*): + True to return the forecast on the context when available, i.e. after the first input patch. + truncate_negative (`bool`, *optional*): + Truncate to only non-negative values if all the contexts have non-negative values. + have non-ne ative values. Returns: - A TimesFmOutputForPrediction object containing: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# inputs, # forecast horizon, 1 + # quantiles). - - loss: the mean squared error loss + quantile loss if future_target is provided. + A TimesFmOutputForPrediction object or a tuple containing: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + - loss: the mean squared error loss + quantile loss if future_target is provided. """ if return_dict is None: return_dict = self.config.use_return_dict From f76116bc65f6f9c467ab3ae8c210aabb906a1b89 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 12:49:28 +0100 Subject: [PATCH 163/242] fix logging --- src/transformers/models/timesfm/modeling_timesfm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index bf66c4d4f5f7..efe417752534 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch TimesFM model.""" -import logging import math from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple, Union @@ -25,7 +24,7 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_timesfm import TimesFmConfig @@ -1023,7 +1022,7 @@ def forward( freq = new_freqs if freq is None: - logging.info("No frequency provided via `freq`. Default to high (0).") + logger.info("No frequency provided via `freq`. Default to high (0).") freq = [0] * len(inputs) if output_attentions is None: @@ -1124,8 +1123,4 @@ def forward( return tuple(return_tuple) -__all__ = [ - "TimesFmModelForPrediction", - "TimesFmPreTrainedModel", - "TimesFmModel", -] +__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] From 0026ba6ff07c5a8ceabef4755c1ad182a13bef7a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 27 Feb 2025 13:10:37 +0100 Subject: [PATCH 164/242] add docstrings to output data classes --- .../models/timesfm/modeling_timesfm.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index efe417752534..f7211b7a847c 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -35,12 +35,30 @@ @dataclass class TimesFmOutput(BaseModelOutput): + """ + Args: + loc (`torch.Tensor` of shape `(batch_size, )`): + The mean of the time series inputs. + scale (`torch.Tensor` of shape `(batch_size,)`): + The scale of the time series inputs. + """ + loc: Optional[torch.Tensor] = None scale: Optional[torch.Tensor] = None @dataclass class TimesFmOutputForPrediction(BaseModelOutput): + """ + Args: + mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The mean predictions of the time series. + full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The full predictions of the time series including the mean and the quantiles. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_target` is provided): + The loss of the TimesFM model. + """ + mean_predictions: Optional[torch.Tensor] = None full_predictions: Optional[torch.Tensor] = None loss: Optional[Union[torch.Tensor, float]] = None From 380e6bff37858c9ac712d464e200c4eb0be822c7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 28 Feb 2025 09:06:18 +0100 Subject: [PATCH 165/242] add _CHECKPOINT_FOR_DOC --- .../models/timesfm/modeling_timesfm.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index f7211b7a847c..0219cabdd409 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -24,12 +24,19 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_timesfm import TimesFmConfig logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch" _CONFIG_FOR_DOC = "TimesFmConfig" @@ -979,6 +986,11 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TimesFmOutputForPrediction, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, inputs: Sequence[torch.Tensor], From 8deeb3e191b3671bc1d74dbfe77b736a066c3d34 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Fri, 28 Feb 2025 11:24:17 -0800 Subject: [PATCH 166/242] fix comments --- .../models/timesfm/modeling_timesfm.py | 64 +++++++++---------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 0219cabdd409..fa4526525f62 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -71,22 +71,19 @@ class TimesFmOutputForPrediction(BaseModelOutput): loss: Optional[Union[torch.Tensor, float]] = None -class TimesFmMLP(nn.Module): +class TimesFmResidualBlock(nn.Module): """Pax MLP in pytorch.""" - def __init__( - self, - hidden_size: int, - intermediate_size: int, - ): + def __init__(self, config: TimesFmConfig): super().__init__() + hidden_size = config.model_dim + intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(hidden_size, intermediate_size) self.down_proj = nn.Linear(intermediate_size, hidden_size) - self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) def forward(self, x, paddings=None): - gate_inp = self.layer_norm(x) - gate = self.gate_proj(gate_inp) + gate = self.gate_proj(x) gate = F.relu(gate) outputs = self.down_proj(gate) if paddings is not None: @@ -94,41 +91,33 @@ def forward(self, x, paddings=None): return outputs + x -class TimesFmResidualBlock(nn.Module): +class TimesFmMlpBlock(nn.Module): """TimesFM residual block.""" def __init__(self, input_dims, hidden_dims, output_dims): super().__init__() - self.input_dims = input_dims - self.hidden_dims = hidden_dims - self.output_dims = output_dims - - # Hidden Layer - self.hidden_layer = nn.Sequential( - nn.Linear(input_dims, hidden_dims), - nn.SiLU(), - ) - - # Output Layer + self.linear = nn.Linear(input_dims, hidden_dims) + self.activation = nn.SiLU() self.output_layer = nn.Linear(hidden_dims, output_dims) - # Residual Layer self.residual_layer = nn.Linear(input_dims, output_dims) def forward(self, x): - hidden = self.hidden_layer(x) + hidden = self.linear(x) + hidden = self.activation(hidden) output = self.output_layer(hidden) residual = self.residual_layer(x) return output + residual class TimesFmRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, config: TimesFmConfig): """ TimesFmRMSNorm is equivalent to T5LayerNorm """ super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + + self.weight = nn.Parameter(torch.ones(config.model_dim)) + self.variance_epsilon = config.rms_norm_eps def forward(self, hidden_states): input_dtype = hidden_states.dtype @@ -144,7 +133,7 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence.""" - def __init__(self, config: TimesFmConfig) -> None: + def __init__(self, config: TimesFmConfig): super().__init__() self.min_timescale = config.min_timescale self.max_timescale = config.max_timescale @@ -376,8 +365,12 @@ def __init__(self, config: TimesFmConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size) - self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) + self.residual_block = TimesFmResidualBlock(config) + self.rms_norm = TimesFmRMSNorm(config) + self.layer_norm = nn.LayerNorm( + normalized_shape=config.model_dim, + eps=config.rms_norm_eps, + ) def forward( self, @@ -390,7 +383,7 @@ def forward( ) -> torch.Tensor: # Self Attention residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.rms_norm(hidden_states) hidden_states, scores = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -401,7 +394,8 @@ def forward( hidden_states = residual + hidden_states # MLP - hidden_states = self.mlp(hidden_states, paddings=paddings) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.residual_block(hidden_states, paddings=paddings) return scores, hidden_states @@ -669,7 +663,7 @@ def _init_weights(self, module): elif isinstance(module, TimesFmRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, TimesFmMLP): + elif isinstance(module, TimesFmResidualBlock): # Initialize gate projection module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.gate_proj.bias is not None: @@ -698,7 +692,7 @@ def _init_weights(self, module): # Initialize scaling parameter nn.init.ones_(module.scaling) - elif isinstance(module, TimesFmResidualBlock): + elif isinstance(module, TimesFmMlpBlock): # Initialize hidden layer module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range) if module.hidden_layer[0].bias is not None: @@ -772,7 +766,7 @@ def __init__(self, config: TimesFmConfig): super().__init__(config) self.config = config - self.input_ff_layer = TimesFmResidualBlock( + self.input_ff_layer = TimesFmMlpBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, hidden_dims=config.intermediate_size, @@ -905,7 +899,7 @@ def __init__(self, config: TimesFmConfig): self.decoder = TimesFmModel(config) # quantile and mean output - self.horizon_ff_layer = TimesFmResidualBlock( + self.horizon_ff_layer = TimesFmMlpBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.intermediate_size, From 92e0b41ed488a11165ae8d88bc022dadda9b6e62 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Mar 2025 12:32:39 +0100 Subject: [PATCH 167/242] Revert "fix comments" This reverts commit 8deeb3e191b3671bc1d74dbfe77b736a066c3d34. --- .../models/timesfm/modeling_timesfm.py | 64 ++++++++++--------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index fa4526525f62..0219cabdd409 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -71,19 +71,22 @@ class TimesFmOutputForPrediction(BaseModelOutput): loss: Optional[Union[torch.Tensor, float]] = None -class TimesFmResidualBlock(nn.Module): +class TimesFmMLP(nn.Module): """Pax MLP in pytorch.""" - def __init__(self, config: TimesFmConfig): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): super().__init__() - hidden_size = config.model_dim - intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(hidden_size, intermediate_size) self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) def forward(self, x, paddings=None): - gate = self.gate_proj(x) + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) gate = F.relu(gate) outputs = self.down_proj(gate) if paddings is not None: @@ -91,33 +94,41 @@ def forward(self, x, paddings=None): return outputs + x -class TimesFmMlpBlock(nn.Module): +class TimesFmResidualBlock(nn.Module): """TimesFM residual block.""" def __init__(self, input_dims, hidden_dims, output_dims): super().__init__() - self.linear = nn.Linear(input_dims, hidden_dims) - self.activation = nn.SiLU() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer self.residual_layer = nn.Linear(input_dims, output_dims) def forward(self, x): - hidden = self.linear(x) - hidden = self.activation(hidden) + hidden = self.hidden_layer(x) output = self.output_layer(hidden) residual = self.residual_layer(x) return output + residual class TimesFmRMSNorm(nn.Module): - def __init__(self, config: TimesFmConfig): + def __init__(self, hidden_size, eps=1e-6): """ TimesFmRMSNorm is equivalent to T5LayerNorm """ super().__init__() - - self.weight = nn.Parameter(torch.ones(config.model_dim)) - self.variance_epsilon = config.rms_norm_eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype @@ -133,7 +144,7 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence.""" - def __init__(self, config: TimesFmConfig): + def __init__(self, config: TimesFmConfig) -> None: super().__init__() self.min_timescale = config.min_timescale self.max_timescale = config.max_timescale @@ -365,12 +376,8 @@ def __init__(self, config: TimesFmConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.residual_block = TimesFmResidualBlock(config) - self.rms_norm = TimesFmRMSNorm(config) - self.layer_norm = nn.LayerNorm( - normalized_shape=config.model_dim, - eps=config.rms_norm_eps, - ) + self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size) + self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) def forward( self, @@ -383,7 +390,7 @@ def forward( ) -> torch.Tensor: # Self Attention residual = hidden_states - hidden_states = self.rms_norm(hidden_states) + hidden_states = self.input_layernorm(hidden_states) hidden_states, scores = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -394,8 +401,7 @@ def forward( hidden_states = residual + hidden_states # MLP - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.residual_block(hidden_states, paddings=paddings) + hidden_states = self.mlp(hidden_states, paddings=paddings) return scores, hidden_states @@ -663,7 +669,7 @@ def _init_weights(self, module): elif isinstance(module, TimesFmRMSNorm): nn.init.zeros_(module.weight) - elif isinstance(module, TimesFmResidualBlock): + elif isinstance(module, TimesFmMLP): # Initialize gate projection module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range) if module.gate_proj.bias is not None: @@ -692,7 +698,7 @@ def _init_weights(self, module): # Initialize scaling parameter nn.init.ones_(module.scaling) - elif isinstance(module, TimesFmMlpBlock): + elif isinstance(module, TimesFmResidualBlock): # Initialize hidden layer module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range) if module.hidden_layer[0].bias is not None: @@ -766,7 +772,7 @@ def __init__(self, config: TimesFmConfig): super().__init__(config) self.config = config - self.input_ff_layer = TimesFmMlpBlock( + self.input_ff_layer = TimesFmResidualBlock( input_dims=2 * config.patch_len, output_dims=config.model_dim, hidden_dims=config.intermediate_size, @@ -899,7 +905,7 @@ def __init__(self, config: TimesFmConfig): self.decoder = TimesFmModel(config) # quantile and mean output - self.horizon_ff_layer = TimesFmMlpBlock( + self.horizon_ff_layer = TimesFmResidualBlock( input_dims=config.model_dim, output_dims=config.horizon_len * (1 + len(config.quantiles)), hidden_dims=config.intermediate_size, From 33fde148ef35c6ba368d58256a61d5002e035412 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Mar 2025 12:50:20 +0100 Subject: [PATCH 168/242] add _prepare_4d_attention_mask --- .../models/timesfm/modeling_timesfm.py | 144 ++++++++---------- 1 file changed, 64 insertions(+), 80 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 0219cabdd409..a2bc773ffdca 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -156,7 +156,7 @@ def forward(self, seq_length=None, position=None): Args: seq_length: an optional Python int defining the output sequence length. if the `position` argument is specified. - position: [B, seq_length], optional position for each token in the + position: [B, seq_length], optional position for each token in the sequence, only required when the sequence is packed. Returns: @@ -411,7 +411,6 @@ class TimesFmStackedDecoder(nn.Module): def __init__(self, config: TimesFmConfig): super().__init__() - self.layers = nn.ModuleList([TimesFmDecoderLayer(config) for _ in range(config.num_layers)]) def forward( @@ -423,9 +422,15 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, ) -> BaseModelOutput: - padding_mask = timesfm_convert_paddings_to_mask(paddings, hidden_states.dtype) - atten_mask = timesfm_causal_mask(hidden_states) - mask = timesfm_merge_masks(padding_mask, atten_mask) + # Convert paddings to attention mask and combine with causal mask + attention_mask = _prepare_4d_attention_mask( + attention_mask=paddings, + sequence_length=hidden_states.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + is_causal=True, + ) + all_attentions = [] all_hidden_states = [] @@ -434,7 +439,7 @@ def forward( kv_cache = kv_caches[i] if kv_caches is not None else None scores, hidden_states = layer( hidden_states=hidden_states, - attention_mask=mask, + attention_mask=attention_mask, paddings=paddings, kv_write_indices=kv_write_indices, kv_cache=kv_cache, @@ -452,7 +457,59 @@ def forward( ) -# Move utility functions here +def timesfm_get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def _prepare_4d_attention_mask( + attention_mask: Optional[torch.Tensor], + sequence_length: int, + dtype: torch.dtype, + device: torch.device, + is_causal: bool = True, +) -> Optional[torch.Tensor]: + """ + Creates 4D attention mask and combines causal and padding masks if needed. + + Args: + attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask + sequence_length: Length of the sequence + dtype: Data type of the mask + device: Device of the mask + is_causal: Whether to apply causal masking + + Returns: + 4D attention mask of shape (batch_size, 1, seq_length, seq_length) + """ + # Handle padding mask + if attention_mask is not None: + # Convert 2D padding mask to 4D attention mask + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask * timesfm_get_large_negative_number(dtype) + + # Create causal mask if needed + if is_causal: + causal_mask = torch.triu( + torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) + * timesfm_get_large_negative_number(dtype), + diagonal=1, + ) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + # Combine with padding mask if it exists + if attention_mask is not None: + attention_mask = torch.minimum(attention_mask, causal_mask) + else: + attention_mask = causal_mask + + return attention_mask + + def timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Calculates mean and standard deviation of `inputs` across axis 1. @@ -551,79 +608,6 @@ def timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Te return [smoothed_arr, arr - smoothed_arr] -def timesfm_get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: - """Returns a large negative value for the given dtype.""" - if dtype.is_floating_point: - dtype_max = torch.finfo(dtype).max - else: - dtype_max = torch.iinfo(dtype).max - return torch.tensor(-0.7 * dtype_max, dtype=dtype) - - -def timesfm_causal_mask(input_t: torch.Tensor) -> torch.Tensor: - """Computes and returns causal mask. - - Args: - input_t: A torch.Tensor of shape [B, T, D]. - - Returns: - An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has - already been converted to large negative values. - """ - assert input_t.dtype.is_floating_point, input_t.dtype - large_negative_number = timesfm_get_large_negative_number(input_t.dtype) - t = input_t.shape[1] - col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) - row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) - mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number - return mask.unsqueeze(0).unsqueeze(0).to(input_t.device) # Equivalent to jnp.newaxis - - -def timesfm_convert_paddings_to_mask(paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Converts binary paddings to a logit mask ready to add to attention matrix. - - Args: - paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding - token. - dtype: data type of the input. - - Returns: - A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. - """ - attention_mask = paddings.detach().clone() - attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis - attention_mask *= timesfm_get_large_negative_number(dtype) - return attention_mask - - -def timesfm_merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Merges 2 masks. - - logscale mask is expected but 0/1 mask is also fine. - - Args: - a: torch.Tensor of shape [1|B, 1, 1|T, S]. - b: torch.Tensor of shape [1|B, 1, 1|T, S]. - - Returns: - torch.Tensor of shape [1|B, 1, 1|T, S]. - """ - - def expand_t(key_mask): - query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose - return torch.minimum(query_mask, key_mask) - - if a.shape[2] != b.shape[2]: - if a.shape[2] == 1: - a = expand_t(a) - else: - assert b.shape[2] == 1 - b = expand_t(b) - - assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." - return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum - - TIMESFM_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads From ca21a2bfa0827f4f9f5b11b96b42818b86c12a7e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Mar 2025 13:04:07 +0100 Subject: [PATCH 169/242] we do not have generative model classes --- tests/models/timesfm/test_modeling_timesfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index c38b1ab07dd0..c88a376c896c 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -131,7 +131,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class TimesFmModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else () - all_generative_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else () + all_generative_model_classes = () all_parallelizable_model_classes = () fx_compatible = False test_pruning = False From bac7f24ac40aa8dccd5d84c73bb604e5636c28aa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Mar 2025 13:36:14 +0100 Subject: [PATCH 170/242] use Cache --- .../models/timesfm/modeling_timesfm.py | 56 +++++++++---------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index a2bc773ffdca..7c654f8247e0 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -22,6 +22,7 @@ import torch.nn as nn import torch.nn.functional as F +from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -217,8 +218,8 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - kv_write_indices: Optional[torch.Tensor] = None, - kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states_shape = hidden_states.shape @@ -236,16 +237,14 @@ def forward( # Write new kv cache. # [batch_size, input_len, n_local_kv_heads, head_dim] - if kv_cache is not None and kv_write_indices is not None: - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - - key = k_cache - value = v_cache + if past_key_value is not None and cache_position is not None: + past_key_value.update(xk, xv, cache_position) + key = past_key_value.get_seq_length() + value = past_key_value.value_states else: key = xk value = xv + if self.num_kv_heads != self.num_heads: # [batch_size, max_seq_len, n_local_heads, head_dim] key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) @@ -267,7 +266,6 @@ def forward( # [batch_size, n_local_heads, input_len, head_dim] output = torch.matmul(scores, v) - # return scores, output.transpose(1, 2).contiguous() # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) @@ -286,16 +284,16 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - kv_write_indices: Optional[torch.Tensor] = None, - kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if output_attentions: return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, + past_key_value=past_key_value, + cache_position=cache_position, output_attentions=output_attentions, ) @@ -315,12 +313,10 @@ def forward( xq = self._scale_query(xq) # Handle KV cache - if kv_cache is not None and kv_write_indices is not None: - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - key = k_cache - value = v_cache + if past_key_value is not None and cache_position is not None: + past_key_value.update(xk, xv, cache_position) + key = past_key_value.key_states + value = past_key_value.value_states else: key = xk value = xv @@ -384,18 +380,18 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, paddings: torch.Tensor, - kv_write_indices: Optional[torch.Tensor] = None, - kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, - ) -> torch.Tensor: + ) -> tuple[Optional[torch.Tensor], torch.Tensor]: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, scores = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, + past_key_value=past_key_value, + cache_position=cache_position, output_attentions=output_attentions, ) hidden_states = residual + hidden_states @@ -417,8 +413,8 @@ def forward( self, hidden_states: torch.Tensor, paddings: torch.Tensor, - kv_write_indices: Optional[torch.Tensor] = None, - kv_caches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + past_key_values: Optional[List[Cache]] = None, + cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, ) -> BaseModelOutput: @@ -436,13 +432,13 @@ def forward( for i in range(len(self.layers)): layer = self.layers[i] - kv_cache = kv_caches[i] if kv_caches is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None scores, hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask, paddings=paddings, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, + past_key_value=past_key_value, + cache_position=cache_position, output_attentions=output_attentions, ) if output_attentions: From f5a3570d4ab6a58e87c0ca80ee67492691940925 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Mar 2025 13:53:22 +0100 Subject: [PATCH 171/242] return past_key_values --- .../models/timesfm/modeling_timesfm.py | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 7c654f8247e0..8a44e0dabdd6 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -16,14 +16,14 @@ import math from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Union import torch import torch.nn as nn import torch.nn.functional as F from ...cache_utils import Cache -from ...modeling_outputs import BaseModelOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, @@ -49,10 +49,14 @@ class TimesFmOutput(BaseModelOutput): The mean of the time series inputs. scale (`torch.Tensor` of shape `(batch_size,)`): The scale of the time series inputs. + past_key_values (`List[Cache]`, *optional*): + Contains the precomputed key and value hidden states of the attention blocks used for + faster decoding. Can be used as a cache for future predictions. """ loc: Optional[torch.Tensor] = None scale: Optional[torch.Tensor] = None + past_key_values: Optional[List[Cache]] = None @dataclass @@ -65,11 +69,15 @@ class TimesFmOutputForPrediction(BaseModelOutput): The full predictions of the time series including the mean and the quantiles. loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_target` is provided): The loss of the TimesFM model. + past_key_values (`List[Cache]`, *optional*): + Contains the precomputed key and value hidden states of the attention blocks used for + faster decoding. Can be used as a cache for future predictions. """ mean_predictions: Optional[torch.Tensor] = None full_predictions: Optional[torch.Tensor] = None loss: Optional[Union[torch.Tensor, float]] = None + past_key_values: Optional[List[Cache]] = None class TimesFmMLP(nn.Module): @@ -239,7 +247,7 @@ def forward( # [batch_size, input_len, n_local_kv_heads, head_dim] if past_key_value is not None and cache_position is not None: past_key_value.update(xk, xv, cache_position) - key = past_key_value.get_seq_length() + key = past_key_value.key_states value = past_key_value.value_states else: key = xk @@ -417,7 +425,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> BaseModelOutput: + ) -> BaseModelOutputWithPast: # Convert paddings to attention mask and combine with causal mask attention_mask = _prepare_4d_attention_mask( attention_mask=paddings, @@ -429,10 +437,11 @@ def forward( all_attentions = [] all_hidden_states = [] + current_past_key_values = [] if past_key_values is None else past_key_values for i in range(len(self.layers)): layer = self.layers[i] - past_key_value = past_key_values[i] if past_key_values is not None else None + past_key_value = current_past_key_values[i] if i < len(current_past_key_values) else None scores, hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask, @@ -445,11 +454,14 @@ def forward( all_attentions.append(scores) if output_hidden_states: all_hidden_states.append(hidden_states) + if past_key_values is None: + current_past_key_values.append(past_key_value) - return BaseModelOutput( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - attentions=all_attentions, - hidden_states=all_hidden_states, + past_key_values=current_past_key_values if current_past_key_values else None, + attentions=all_attentions if output_attentions else None, + hidden_states=all_hidden_states if output_hidden_states else None, ) @@ -630,8 +642,12 @@ class TimesFmPreTrainedModel(PreTrainedModel): config_class = TimesFmConfig base_model_prefix = "timesfm" + _no_split_modules = ["TimesFmDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] main_input_name = "inputs" _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): if isinstance(module, nn.Embedding): @@ -861,6 +877,7 @@ def forward( attentions=transformer_output.attentions if output_attentions else None, loc=stats[0], scale=stats[1], + past_key_values=transformer_output.past_key_values, ) else: return ( @@ -869,6 +886,7 @@ def forward( transformer_output.attentions, stats[0], stats[1], + transformer_output.past_key_values, ) @@ -1122,6 +1140,7 @@ def forward( mean_predictions=mean_outputs, full_predictions=full_outputs, loss=loss, + past_key_values=decoder_output.past_key_values, ) else: return_tuple = [decoder_output.last_hidden_state] @@ -1130,6 +1149,7 @@ def forward( if output_attentions: return_tuple.append(decoder_output.attentions) return_tuple += [mean_outputs, full_outputs, loss] + return_tuple += [decoder_output.past_key_values] return tuple(return_tuple) From 7b00789f73d6980c51b8ff61aa86352ab891d318 Mon Sep 17 00:00:00 2001 From: Jinan Zhou Date: Mon, 3 Mar 2025 14:12:14 -0800 Subject: [PATCH 172/242] modules initialized with config only --- .../models/timesfm/modeling_timesfm.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8a44e0dabdd6..27080ae71a37 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -83,12 +83,11 @@ class TimesFmOutputForPrediction(BaseModelOutput): class TimesFmMLP(nn.Module): """Pax MLP in pytorch.""" - def __init__( - self, - hidden_size: int, - intermediate_size: int, - ): + def __init__(self, config: TimesFmConfig): super().__init__() + hidden_size = config.model_dim + intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(hidden_size, intermediate_size) self.down_proj = nn.Linear(intermediate_size, hidden_size) self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) @@ -131,13 +130,13 @@ def forward(self, x): class TimesFmRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, config: TimesFmConfig): """ TimesFmRMSNorm is equivalent to T5LayerNorm """ super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(config.model_dim)) + self.variance_epsilon = config.rms_norm_eps def forward(self, hidden_states): input_dtype = hidden_states.dtype @@ -153,7 +152,7 @@ def extra_repr(self): class TimesFmPositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence.""" - def __init__(self, config: TimesFmConfig) -> None: + def __init__(self, config: TimesFmConfig): super().__init__() self.min_timescale = config.min_timescale self.max_timescale = config.max_timescale @@ -380,8 +379,8 @@ def __init__(self, config: TimesFmConfig): attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation] self.self_attn = attention_class(config) - self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size) - self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps) + self.mlp = TimesFmMLP(config) + self.input_layernorm = TimesFmRMSNorm(config) def forward( self, From 019c6a243bf712e3e2dcb490c41e1303e219b9c7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Mar 2025 17:32:49 +0100 Subject: [PATCH 173/242] update year --- docs/source/en/model_doc/timesfm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md index 88366594e803..2d28c9897ea8 100644 --- a/docs/source/en/model_doc/timesfm.md +++ b/docs/source/en/model_doc/timesfm.md @@ -1,4 +1,4 @@ -