Add TimesFM Time Series Forecasting Model #34082
Merged
282 commits
f539965
initial documentation
kashif f95d6ee
rename mask to attention_mask
kashif d475529
smaller tests
kashif f7c1fe0
fixup
kashif 5a808be
fix copies
kashif 6dbbd80
move to time series section
kashif bfa9302
sort docs
kashif eb5807e
isort fix
kashif be32365
batch_size is not a configuration
kashif c4a3610
rename to TimesFMModelForPrediction
kashif 56a5606
initial script
kashif 01756ae
add check_outputs
kashif 942c23c
remove dropout_rate
kashif e7650bd
works with torch.Tensor inputs
kashif c523f64
rename script
kashif e64f562
fix docstrings
kashif f5dbab9
fix freq when window_size is given
kashif a8dcfa9
add loss
kashif 5fb1fe0
fix _quantile_loss
kashif 1445fe5
formatting
kashif b84c188
Merge branch 'main' into timesfm
kashif 1c1804e
fix isort
kashif 35a7e9f
add weight init
kashif 4bfc95c
Merge branch 'main' into timesfm
kashif 6c4dded
add support for sdpa and flash_attention_2
kashif 82c697c
fixes for flash_attention
kashif 3bedf98
formatting
kashif 2b4f55c
remove flash_attention
kashif 5c4c591
fix tests
kashif f924a31
fix file name
kashif 84f763e
fix quantile loss
kashif ee1e289
added initial TimesFMModelIntegrationTests
kashif b9e9633
fix formatting
kashif d3753ff
Merge branch 'main' into timesfm
kashif fce6cf4
Merge branch 'main' into timesfm
kashif c9dede6
fix import order
kashif bc67797
fix _quantile_loss
kashif 61d5e89
add doc for SDPA
kashif 5ea1698
Merge branch 'main' into timesfm
kashif ece0896
use timesfm 2.0
kashif 21e3236
bug fix in timesfm decode function.
rajatsen91 f173a8e
compare mean forecasts
kashif 4e5196a
Merge branch 'main' into timesfm
kashif b83023c
refactor type hints, use CamelCase
jinan-zhou b21ec50
consolidate decode func
jinan-zhou 591874f
Merge branch 'main' into timesfm
kashif e8dfab0
Merge branch 'main' into timesfm
kashif e162102
more readable code for weight conversion
jinan-zhou 8d614ae
Merge branch 'main' into timesfm
kashif e7531e1
fix-copies
kashif 0cfb2c3
simpler init
kashif 2e29e5f
renaem TimesFmMLP
kashif 5dc2927
use T5LayerNorm
kashif 7180f79
fix tests
kashif cdb4239
use initializer_range
kashif c48d673
TimesFmModel instead of TimesFmDecoder
kashif ce5f216
TimesFmPositionalEmbedding takes config for its init
kashif 9453ed9
2.0-500m-pytorch default configs
kashif 61c96fd
use TimesFmModel
kashif 9538c1d
fix formatting
kashif bfa69e7
ignore TimesFmModel for testing
kashif 80d8809
Merge branch 'main' into timesfm
kashif c34286f
fix docstring
kashif 72ae8f5
Merge branch 'main' into timesfm
kashif e401b33
override generate as its not needed
kashif 85446e3
add doc strings
kashif c410cde
fix logging
kashif 8d5a210
add docstrings to output data classes
kashif c2625e0
initial copy from t5
kashif f43a0df
added config and attention layers
kashif 8bbda06
add TimesFMPositionalEmbedding
kashif 5178c11
calcuate scale_factor once
kashif 95a06a9
add more configs and TimesFMResidualBlock
kashif 3be5893
fix input_dims
kashif 9fb8bf8
standardize code format with black
jinan-zhou f79803c
remove unneeded modules
jinan-zhou a81e99b
TimesFM Model
jinan-zhou 1ec48c7
order of imports
kashif 8abfc2e
copy from Google official implementation
jinan-zhou 7e0305a
remove covariate forecasting
jinan-zhou c042a9d
Adapting TimesFM to HF format
jinan-zhou a52eeca
restructing in progress
jinan-zhou c7f760e
adapted to HF convention
jinan-zhou d717132
timesfm test
jinan-zhou 72ffaaf
the model runs
jinan-zhou 3818ee4
fixing unit tests
jinan-zhou 0013655
fixing unit tests in progress
jinan-zhou 6419285
add post_init
kashif 7cd2e41
do not change TimesFMOutput
kashif 47affe8
fixing unit tests
jinan-zhou bbf738c
all unit tests passed
jinan-zhou bb2a850
remove timesfm_layers
kashif c55088d
add intermediate_size and initialize with config
kashif fd270d9
initial documentation
kashif 9bb5a49
rename mask to attention_mask
kashif 5376dd7
smaller tests
kashif 8edb51e
fixup
kashif e8e31cd
fix copies
kashif 5b18440
move to time series section
kashif 5ebeec2
sort docs
kashif f810125
isort fix
kashif 7e5921c
batch_size is not a configuration
kashif 906d6a8
rename to TimesFMModelForPrediction
kashif c30e748
initial script
kashif d7d3a13
add check_outputs
kashif c3fbff2
remove dropout_rate
kashif 9e6750c
works with torch.Tensor inputs
kashif b437e87
rename script
kashif 9f0f086
fix docstrings
kashif f9e5db8
fix freq when window_size is given
kashif c8703ff
add loss
kashif 8f6c2e1
fix _quantile_loss
kashif b319873
formatting
kashif 3bd0827
fix isort
kashif 0d4325e
add weight init
kashif 4212ef8
add support for sdpa and flash_attention_2
kashif 9739e4b
fixes for flash_attention
kashif 33cee01
formatting
kashif bce6405
remove flash_attention
kashif fb33f35
fix tests
kashif b41c368
fix file name
kashif 9aad101
fix quantile loss
kashif be8922f
added initial TimesFMModelIntegrationTests
kashif c468644
fix formatting
kashif 689d2a4
fix import order
kashif abb1c0a
fix _quantile_loss
kashif 686c71b
add doc for SDPA
kashif 91c50a4
use timesfm 2.0
kashif cef8510
bug fix in timesfm decode function.
rajatsen91 7c7e56f
compare mean forecasts
kashif 22bb7cf
refactor type hints, use CamelCase
jinan-zhou 53b290a
consolidate decode func
jinan-zhou c65e4b4
more readable code for weight conversion
jinan-zhou b428972
fix-copies
kashif ea05e27
simpler init
kashif 038859d
renaem TimesFmMLP
kashif ef59621
use T5LayerNorm
kashif d8c2e0d
fix tests
kashif a75b8e7
use initializer_range
kashif 5352cda
TimesFmModel instead of TimesFmDecoder
kashif f460370
TimesFmPositionalEmbedding takes config for its init
kashif 913f360
2.0-500m-pytorch default configs
kashif 02e62c6
use TimesFmModel
kashif 4466315
fix formatting
kashif df7bbb0
ignore TimesFmModel for testing
kashif c0a4f48
fix docstring
kashif 71bda44
override generate as its not needed
kashif b7e75e9
add doc strings
kashif f76116b
fix logging
kashif 0026ba6
add docstrings to output data classes
kashif 909fd6c
Merge branch 'timesfm' of https://github.com/kashif/transformers into…
kashif 380e6bf
add _CHECKPOINT_FOR_DOC
kashif 8deeb3e
fix comments
jinan-zhou 92e0b41
Revert "fix comments"
kashif 33fde14
add _prepare_4d_attention_mask
kashif 5f7bffb
Merge branch 'main' into timesfm
kashif ca21a2b
we do not have generative model classes
kashif bac7f24
use Cache
kashif f5a3570
return past_key_values
kashif a53195c
Merge branch 'main' into timesfm
kashif 7b00789
modules initialized with config only
jinan-zhou 8342c11
Merge branch 'main' into timesfm
jinan-zhou 921c0bd
Merge branch 'main' into timesfm
kashif 019c6a2
update year
kashif 32065cc
Update docs/source/en/model_doc/timesfm.md
kashif 4a1687b
add layer_idx to cache
kashif e6d77dd
modular timesfm
kashif c236313
fix test
kashif e383fcb
Merge branch 'huggingface:main' into timesfm
jinan-zhou b0354f0
unwrap sequential class
jinan-zhou ace1363
fix toctree
jinan-zhou df91360
remove TimesFmOnnxConfig
kashif 5dc0c38
Merge branch 'timesfm' of https://github.com/kashif/transformers into…
kashif fce6d1f
fix modular
kashif 9da15fd
remove TimesFmStackedDecoder
kashif 94126c6
split qkv layer into individual layers
kashif 006e97a
rename projection layers
kashif 9f0b4c4
use ALL_ATTENTION_FUNCTIONS
kashif b587467
is_causal is True
kashif 270d99b
Merge branch 'main' into timesfm
kashif 156051d
rename config
kashif cf733d0
Merge branch 'timesfm' of https://github.com/kashif/transformers into…
kashif dc46013
does not support flash_attn_2
kashif d9b1cca
formatting
kashif 6cc2ea6
fix typo in docsstring
kashif 66b0af6
Merge branch 'main' into timesfm
kashif 3c036f6
rename inputs
kashif 63dba1c
add time series mapping
kashif 768b5f5
Update src/transformers/models/olmo2/modeling_olmo2.py
kashif 1d20534
Update src/transformers/models/moonshine/modeling_moonshine.py
kashif 620dcd0
use updated arguments
kashif 8da0298
fix class name
kashif 36f0298
add MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
kashif 28671d9
isort
kashif 28c881b
Merge branch 'main' into timesfm
kashif 347e7e7
consolidate _preprocess into forward
jinan-zhou 41f6f55
fix a typo
jinan-zhou 8b46515
fix a typo
jinan-zhou c646a84
fix toc
kashif fc95e56
fix modular
kashif 4e2ad7e
Merge branch 'main' into timesfm
kashif 466e88c
Merge branch 'main' into timesfm
kashif e0d163c
remove aaserts
kashif 2acc0ac
use self.config._attn_implementation
kashif 567a45d
move to _postprocess_output
kashif cf90818
remove timesfm_get_large_negative_number
kashif 666fc92
use view unstead of multiple unsqueeze
kashif d7429e9
make helpers static methods of the Model
kashif 9d9d6f2
use to_tuple
kashif a6d049f
use to_tuple if not return_dict
kashif 942b607
remove unused intitialization block as its incorporated in nn.Linear
kashif d6314c8
remove unused num_key_value_groups
kashif ca68d43
use the same convention as the masking method
kashif 8d158f3
Merge branch 'main' into timesfm
kashif 5191573
update modular
kashif f47a1e7
do not use unsqueeze
kashif 936a2d6
use view instead of unsqueeze
kashif a88dae7
Merge branch 'main' into timesfm
kashif 271b169
Merge branch 'main' into timesfm
kashif 5b40f25
use buffer for inv_timescales
kashif a7f85ce
formatting
kashif 9685037
modular conversion
kashif b88a984
remove unneeded intialization
kashif 49eed00
add missing docstrings
kashif 649f2a6
remove cache
kashif a2e3f05
Merge branch 'main' into timesfm
kashif 07669d2
use simple_eager_attention_forward
kashif 08df212
support tp_plan
kashif def36c4
support for flex and flash attention masks
kashif 5cc47cd
Revert "support for flex and flash attention masks"
kashif 5e3a5e2
fix device
kashif 7da546f
fix tests on gpu
kashif debb032
remove unsued large model test
kashif 2a0c209
removed unneeded comments
kashif b1c3c49
Merge branch 'main' into timesfm
kashif 87e8b12
add example usage
kashif 0493f61
Merge branch 'main' into timesfm
kashif 70c3cb5
fix style
kashif 76f72fb
add import
kashif aa721d4
Merge branch 'main' into timesfm
kashif e7882d7
Update docs/source/en/model_doc/timesfm.md
kashif a86136d
inherit from LlamaRMSNorm
kashif 60e7e65
use can_return_tuple decorator
kashif a5b9010
remvoe return_dict
kashif ca86584
fix year
kashif 0b711f1
Merge branch 'main' into timesfm
kashif 531f8e3
Merge branch 'main' into timesfm
Cyrilvallez 9b76c0b
Update docs/source/en/model_doc/timesfm.md
kashif 0161ca9
pretrained does not inherit from GenerationMixin
kashif fa53c52
use model for integration test
kashif
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| <!--Copyright 2025 The HuggingFace Team. All rights reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
| the License. You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
| specific language governing permissions and limitations under the License. | ||
|
|
||
| ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | ||
| rendered properly in your Markdown viewer. | ||
|
|
||
| --> | ||
|
|
||
| # TimesFM | ||
|
|
||
| <div class="flex flex-wrap space-x-1"> | ||
| <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white"> | ||
| </div> | ||
|
|
||
| ## Overview | ||
|
|
||
| TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder-only model that takes non-overlapping patches of time-series data as input and autoregressively predicts one output patch of future values at a time. | ||
|
|
||
|
|
||
| The abstract from the paper is the following: | ||
|
|
||
| *Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.* | ||
|
|
||
|
|
||
| This model was contributed by [kashif](https://huggingface.co/kashif). | ||
| The original code can be found [here](https://github.com/google-research/timesfm). | ||
|
|
||
|
|
||
|
|
||
| To use the model: | ||
|
|
||
| ```python | ||
| import numpy as np | ||
| import torch | ||
| from transformers import TimesFmModelForPrediction | ||
|
|
||
|
|
||
| model = TimesFmModelForPrediction.from_pretrained( | ||
| "google/timesfm-2.0-500m-pytorch", | ||
| torch_dtype=torch.bfloat16, | ||
| attn_implementation="sdpa", | ||
| device_map="cuda" if torch.cuda.is_available() else None | ||
| ) | ||
|
|
||
|
|
||
| # Create dummy inputs | ||
| forecast_input = [ | ||
| np.sin(np.linspace(0, 20, 100)), | ||
| np.sin(np.linspace(0, 20, 200)), | ||
| np.sin(np.linspace(0, 20, 400)), | ||
| ] | ||
| frequency_input = [0, 1, 2] | ||
|
|
||
| # Convert inputs to sequence of tensors | ||
| forecast_input_tensor = [ | ||
| torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu") | ||
| for ts in forecast_input | ||
| ] | ||
| frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to( | ||
| "cuda" if torch.cuda.is_available() else "cpu" | ||
| ) | ||
|
|
||
| # Get predictions from the pre-trained model | ||
| with torch.no_grad(): | ||
|     outputs = model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) | ||
|     point_forecast_conv = outputs.mean_predictions.float().cpu().numpy() | ||
|     quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy() | ||
| ``` | ||
|
|
||
| ## TimesFmConfig | ||
|
|
||
| [[autodoc]] TimesFmConfig | ||
|
|
||
| ## TimesFmModel | ||
|
|
||
| [[autodoc]] TimesFmModel | ||
| - forward | ||
|
|
||
| ## TimesFmModelForPrediction | ||
|
|
||
| [[autodoc]] TimesFmModelForPrediction | ||
| - forward | ||
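
The usage example in the documentation above returns quantile forecasts (`full_predictions`) next to the mean forecast. As a rough illustration of how a quantile forecast is usually scored, the sketch below computes the pinball (quantile) loss in plain Python. The names `pinball_loss` and `mean_pinball_loss` and the toy numbers are hypothetical, and this is not the `_quantile_loss` implemented in this PR.

```python
from typing import Sequence


def pinball_loss(target: float, prediction: float, quantile: float) -> float:
    """Pinball loss for one observation at a single quantile level in (0, 1)."""
    if not 0.0 < quantile < 1.0:
        raise ValueError("quantile must lie strictly between 0 and 1")
    error = target - prediction
    # Under-prediction (error > 0) is weighted by q, over-prediction by (1 - q).
    return max(quantile * error, (quantile - 1.0) * error)


def mean_pinball_loss(
    targets: Sequence[float], predictions: Sequence[float], quantile: float
) -> float:
    """Average pinball loss over a forecast horizon."""
    if len(targets) != len(predictions) or len(targets) == 0:
        raise ValueError("targets and predictions must be non-empty and equally long")
    total = sum(pinball_loss(t, p, quantile) for t, p in zip(targets, predictions))
    return total / len(targets)


# Toy values, purely illustrative.
targets = [10.0, 12.0, 11.0]
low_forecast = [8.0, 9.0, 9.0]      # hypothetical 0.1-quantile forecast
high_forecast = [13.0, 14.0, 13.0]  # hypothetical 0.9-quantile forecast

print(mean_pinball_loss(targets, low_forecast, 0.1))
print(mean_pinball_loss(targets, high_forecast, 0.9))
```

A low quantile is penalised lightly when it sits below the target and heavily when it sits above it, which is what pushes each quantile head toward its own level.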
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| # Copyright 2025 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from ...utils import _LazyModule | ||
| from ...utils.import_utils import define_import_structure | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from .configuration_timesfm import * | ||
| from .modeling_timesfm import * | ||
| else: | ||
| import sys | ||
|
|
||
| _file = globals()["__file__"] | ||
| sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) |
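
The `__init__.py` above replaces the package's entry in `sys.modules` with a lazy module, so the modeling code is only imported when one of its names is first accessed. The sketch below shows that general idea using only the standard library. `LazyModule` and the `lazy_demo` name are hypothetical stand-ins, and the real `_LazyModule` in `transformers` may differ in its details.

```python
import importlib
import sys
import types
from typing import Dict


class LazyModule(types.ModuleType):
    """Hypothetical stand-in: imports an attribute's source module on first access."""

    def __init__(self, name: str, attr_to_module: Dict[str, str]):
        super().__init__(name)
        self._attr_to_module = dict(attr_to_module)

    def __getattr__(self, attr: str):
        # Only reached when normal lookup fails, i.e. before the attribute is cached.
        try:
            module_name = self._attr_to_module[attr]
        except KeyError:
            raise AttributeError(
                f"module {self.__name__!r} has no attribute {attr!r}"
            ) from None
        value = getattr(importlib.import_module(module_name), attr)
        setattr(self, attr, value)  # cache so later lookups bypass __getattr__
        return value


# Map exported names to the modules that define them (stdlib modules for the demo).
lazy = LazyModule("lazy_demo", {"dumps": "json", "sqrt": "math"})
sys.modules["lazy_demo"] = lazy

print("sqrt" in vars(lazy))  # not resolved yet
print(lazy.sqrt(16.0))       # triggers the import of math
print("sqrt" in vars(lazy))  # now cached on the module object
```

Under `TYPE_CHECKING` the real file uses ordinary star imports instead, so static type checkers still see the exported names.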
129 changes: 129 additions & 0 deletions
src/transformers/models/timesfm/configuration_timesfm.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2025 Google LLC and HuggingFace Inc. team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """TimesFM model configuration""" | ||
|
|
||
| from typing import List | ||
|
|
||
| from ...configuration_utils import PretrainedConfig | ||
| from ...utils import logging | ||
|
|
||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
| class TimesFmConfig(PretrainedConfig): | ||
| r""" | ||
| This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`] or a [`TimesFmModel`]. It is used to | ||
| instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a | ||
| configuration with the defaults will yield a similar configuration to that of the TimesFM | ||
| [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture. | ||
|
|
||
| Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | ||
| documentation from [`PretrainedConfig`] for more information. | ||
|
|
||
| Arguments: | ||
| patch_length (`int`, *optional*, defaults to 32): | ||
| The length of one patch in the input sequence. | ||
| context_length (`int`, *optional*, defaults to 512): | ||
| The length of the input context. | ||
| horizon_length (`int`, *optional*, defaults to 128): | ||
| The length of the prediction horizon. | ||
| freq_size (`int`, *optional*, defaults to 3): | ||
| The number of frequency embeddings. | ||
| num_hidden_layers (`int`, *optional*, defaults to 50): | ||
| Number of Transformer layers. | ||
| hidden_size (`int`, *optional*, defaults to 1280): | ||
| Size of the hidden layers in the feed-forward networks. | ||
| intermediate_size (`int`, *optional*, defaults to 1280): | ||
| Dimension of the MLP representations. | ||
| head_dim (`int`, *optional*, defaults to 80): | ||
| Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will | ||
| be defined as `num_attention_heads * head_dim`. | ||
| num_attention_heads (`int`, *optional*, defaults to 16): | ||
| Number of attention heads for each attention layer in the Transformer decoder. | ||
| tolerance (`float`, *optional*, defaults to 1e-06): | ||
| The tolerance for the quantile loss. | ||
| rms_norm_eps (`float`, *optional*, defaults to 1e-06): | ||
| The epsilon used by the RMS normalization layers. | ||
| quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`): | ||
| The quantiles to predict. | ||
| pad_val (`float`, *optional*, defaults to 1123581321.0): | ||
| The value used to pad the predictions. | ||
| attention_dropout (`float`, *optional*, defaults to 0.0): | ||
| The dropout probability for the attention scores. | ||
| use_positional_embedding (`bool`, *optional*, defaults to `False`): | ||
| Whether to add positional embeddings. | ||
| initializer_range (`float`, *optional*, defaults to 0.02): | ||
| The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | ||
| min_timescale (`int`, *optional*, defaults to 1): | ||
| The start of the geometric positional index. Determines the periodicity of | ||
| the added signal. | ||
| max_timescale (`int`, *optional*, defaults to 10000): | ||
| The end of the geometric positional index. Determines the frequency of the | ||
| added signal. | ||
| """ | ||
|
|
||
| model_type = "timesfm" | ||
| keys_to_ignore_at_inference = [] | ||
| is_encoder_decoder = False | ||
|
|
||
| def __init__( | ||
| self, | ||
| patch_length: int = 32, | ||
| context_length: int = 512, | ||
| horizon_length: int = 128, | ||
| freq_size: int = 3, | ||
| num_hidden_layers: int = 50, | ||
| hidden_size: int = 1280, | ||
| intermediate_size: int = 1280, | ||
| head_dim: int = 80, | ||
| num_attention_heads: int = 16, | ||
| tolerance: float = 1e-6, | ||
| rms_norm_eps: float = 1e-6, | ||
| quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], | ||
| pad_val: float = 1123581321.0, | ||
| attention_dropout: float = 0.0, | ||
| use_positional_embedding: bool = False, | ||
| initializer_range: float = 0.02, | ||
| min_timescale: int = 1, | ||
| max_timescale: int = 10_000, | ||
| **kwargs, | ||
| ): | ||
| self.patch_length = patch_length | ||
| self.context_length = context_length | ||
| self.horizon_length = horizon_length | ||
| self.quantiles = quantiles | ||
| self.pad_val = pad_val | ||
| self.freq_size = freq_size | ||
| self.hidden_size = hidden_size | ||
| self.intermediate_size = intermediate_size | ||
| self.head_dim = head_dim | ||
| self.num_hidden_layers = num_hidden_layers | ||
| self.num_attention_heads = num_attention_heads | ||
| self.tolerance = tolerance | ||
| self.rms_norm_eps = rms_norm_eps | ||
| self.attention_dropout = attention_dropout | ||
| self.use_positional_embedding = use_positional_embedding | ||
| self.initializer_range = initializer_range | ||
| self.min_timescale = min_timescale | ||
| self.max_timescale = max_timescale | ||
|
|
||
| super().__init__( | ||
| is_encoder_decoder=self.is_encoder_decoder, | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
||
| __all__ = ["TimesFmConfig"] |
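
To make the relationship between a few of these defaults concrete, here is a small sketch in plain Python. `ToyTimesFmShapes`, `split_into_patches`, and `decode_steps` are hypothetical helpers written for illustration and are not part of this PR. Left-padding with zeros and emitting `horizon_length` points per autoregressive step are assumptions of the sketch, not statements about the model's actual preprocessing.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class ToyTimesFmShapes:
    """Hypothetical stand-in holding a few TimesFmConfig defaults from the diff."""

    patch_length: int = 32
    context_length: int = 512
    horizon_length: int = 128
    head_dim: int = 80
    num_attention_heads: int = 16

    @property
    def inner_dim(self) -> int:
        # Per the docstring: inner_dim = num_attention_heads * head_dim.
        return self.num_attention_heads * self.head_dim

    @property
    def num_context_patches(self) -> int:
        return self.context_length // self.patch_length


def split_into_patches(
    series: Sequence[float], patch_length: int, pad_value: float = 0.0
) -> List[List[float]]:
    """Left-pad to a multiple of patch_length, then cut non-overlapping patches."""
    pad = (patch_length - len(series) % patch_length) % patch_length
    padded = [pad_value] * pad + list(series)
    return [padded[i : i + patch_length] for i in range(0, len(padded), patch_length)]


def decode_steps(forecast_length: int, horizon_length: int) -> int:
    """Autoregressive steps needed if each step emits horizon_length points."""
    return math.ceil(forecast_length / horizon_length)


shapes = ToyTimesFmShapes()
print(shapes.inner_dim)            # 1280, equal to the default hidden_size
print(shapes.num_context_patches)  # 16
print(len(split_into_patches(list(range(70)), shapes.patch_length)))  # 3
print(decode_steps(300, shapes.horizon_length))  # 3
```

With the defaults, the per-head projections together span `16 * 80 = 1280` dimensions, which matches the default `hidden_size`.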