Add stochastic mixer for supernet training #373
Merged
Changes from all commits (36 commits, all authored by tscholak):
43729b1 Add stochastic mixer for supernet training
8b1eb08 Fix stochastic mixer test failures
8ada30b Fix stochastic mixer checkpoint conversion
cd1dbf8 Handle lossy HF conversions for stochastic mixer
d0fd648 Merge remote-tracking branch 'origin/main' into stochastic-mixer
d693f74 Clean up extra blank line in huggingface.py
6962de9 Apply pre-commit formatting
a96c0cb Refactor stochastic mixer: set main_mixer_name in validation, preproc…
735ee3f wip
aed779c resolve merge conflicts
982d409 Implement full stochastic mixer support in Apriel HuggingFace format
0d8ab4d Add Apriel2 checkpoint format and fix weight tying
bcd93b2 Optimize Apriel2: compute position embeddings and masks per unique block
ebe75c4 Add HuggingFace generation and caching improvements to Apriel2
ffd55e5 Add Apriel2DynamicCache for hybrid attention/SSM layer support
fe259c3 Add Mamba incremental generation support to Apriel2
708917d Add GatedDeltaNet support via Qwen3NextGatedDeltaNet wrapper
77ceae2 Standardize naming: recurrent_states and Apriel2 prefixes
ec95ccc Remove debug print statements and irrelevant changes
571fede Remove stochastic mixer support from apriel conversion
8e7c154 Remove trivial formatting change from apriel_hybrid_ssm config
4d0a01b Remove test changes for lossy HF conversion
71cf778 Revert trivial setup.py formatting and restore .eval() calls in tests
75847d0 Rename SamplingStrategy to StochasticMixerSamplingStrategy
eacdf61 Use normalize_probabilities for sampling weights validation
192e985 Remove tools/supernet_beam_search.py
2fe9596 Fix stochastic mixer sampling to be consistent across all ranks
acb4751 Add Apriel2Cache with JetNemotron pattern and HF Cache compliance
a1d5f07 Add pytest-style test structure for Apriel2 with cache bug fixes
e044282 Add comprehensive Apriel2 modeling tests with cache verification
e830cc5 Add comprehensive mixer testing: all 4 types + switching behavior
f113d8d Refactor Mamba availability checking and add CPU fallbacks
a8ccc1b Apply code formatting to Apriel2 cache and config
a265e8c Merge main into stochastic-mixer
a1c94b7 Fix StochasticMixer bugs after merge
a847129 Add seed_shift to StochasticMixer for reproducibility
One hunk adds `/.devcontainer/` to the project-specific ignore patterns (the file name is not preserved in the capture, but the context matches a `.gitignore`):

```diff
@@ -28,6 +28,7 @@ venv.bak/
 # Project specifics
 /.idea/
 /.vscode/
+/.devcontainer/

 # Devenv
 .devenv*
```
The main addition is a new 173-line module implementing the `StochasticMixer` block:

```python
import logging
import typing

import torch

from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.layers.decoder.config import (
    StochasticMixerConfig,
    StochasticMixerKwargs,
    StochasticMixerSamplingStrategy,
)
from fast_llm.tensor import TensorMeta

logger = logging.getLogger(__name__)


class StochasticMixer[ConfigType: StochasticMixerConfig](BlockWithBias[ConfigType]):
    """
    A mixer that stochastically samples from multiple mixer options during training.

    In training mode, each forward pass randomly selects one mixer according to
    the sampling strategy. In eval mode, uses the configured inference mixer.

    This is useful for supernet training where you want to train multiple
    architecture variants (e.g., attention vs. Mamba) with different data subsets.
    """

    _config: ConfigType

    def __init__(
        self,
        config: ConfigType,
        distributed_config: DistributedConfig,
        *,
        hidden_dim: TensorDim,
        lr_scale: float | None,
        peft: PeftConfig | None,
        return_bias: bool = True,
    ):
        super().__init__(
            config,
            distributed_config,
            hidden_dim=hidden_dim,
            lr_scale=lr_scale,
            peft=peft,
            return_bias=return_bias,
        )

        # Initialize all mixers
        self.mixers = torch.nn.ModuleDict(
            {
                name: mixer_config.get_layer(
                    distributed_config,
                    hidden_dim,
                    lr_scale=lr_scale,
                    peft=peft,
                    return_bias=return_bias,
                )
                for name, mixer_config in self._config.mixers.items()
            }
        )

        if self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform:
            self._sampling_probs = torch.ones(len(self.mixers), device="cpu") / len(self.mixers)
        elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.weighted:
            if self._config.sampling_weights is None:
                raise ValueError("sampling_weights must be provided when using weighted sampling strategy")
            self._sampling_probs = torch.tensor(
                [self._config.sampling_weights[name] for name in self.mixers.keys()],
                dtype=torch.float32,
                device="cpu",
            )
        else:
            raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented")

        logger.info(
            f"Initialized StochasticMixer with {len(self.mixers)} mixers: "
            f"{', '.join(f'{name}={type(mixer).__name__}' for name, mixer in self.mixers.items())} "
            f"(main={self._config.main_mixer_name})"
        )

        # Mark all mixer parameters with allow_no_grad since only one mixer
        # is active per forward pass during training. Even though all mixers
        # will eventually be trained, on any single forward pass, the non-selected
        # mixers won't receive gradients.
        for mixer in self.mixers.values():
            for param in mixer.parameters(recurse=True):
                if hasattr(param, "allow_no_grad"):
                    param.allow_no_grad = True
```
Collaborator comment on the `allow_no_grad` block: "One thing to keep in mind: all mixers will still go through gradient reduction and the optimizer step, and the unused ones will get a zero gradient. Probably still ok, but it's suboptimal and could affect optimizer momenta. I don't think there is an easy fix, though."
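A minimal standalone sketch (not part of the PR) of the effect this comment describes: with Adam, a parameter that receives a zero gradient on a step can still move, because the first-moment estimate accumulated on earlier steps decays only gradually.

```python
import torch

# Hypothetical illustration: one scalar parameter trained with Adam.
p = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.Adam([p], lr=0.1)

# Step 1: the parameter receives a real gradient (its mixer was selected).
p.grad = torch.tensor([0.5])
opt.step()

# Step 2: the mixer was not selected, so the gradient is zero, yet Adam's
# momentum (exp_avg) from step 1 is still nonzero, so p moves anyway.
before = p.detach().clone()
p.grad = torch.tensor([0.0])
opt.step()
print((p.detach() - before).abs().item() > 0)  # True: update despite zero grad
```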
The rest of the module:

```python
    def setup(self, distributed: Distributed) -> None:
        """Setup all mixers with the distributed context."""
        super().setup(distributed)
        for mixer in self.mixers.values():
            mixer.setup(distributed)

    def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
        if not self.training:
            return self._config.main_mixer_name

        generator = kwargs[StochasticMixerKwargs.generator]
        mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item()
        return list(self.mixers.keys())[mixer_idx]

    def _forward(
        self,
        input_: torch.Tensor,
        kwargs: dict[str, typing.Any],
        losses: dict[str, typing.Any] | None = None,
        metrics: dict[str, typing.Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        mixer_name = self._sample_mixer_name(kwargs)

        if self._debug.enabled:
            logger.debug(f"StochasticMixer selecting mixer {mixer_name}: {type(self.mixers[mixer_name]).__name__}")

        return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics)

    def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
        from fast_llm.engine.distributed.config import MAX_SEED
        from fast_llm.layers.block.config import BlockKwargs

        # Derive the sampling seed from the global seed and the iteration only,
        # so every rank samples the same mixer on the same step.
        iteration = kwargs[BlockKwargs.iteration]
        generator = torch.Generator(device="cpu")
        seed = (self._distributed_config.seed + self._config.seed_shift + iteration) % MAX_SEED
        generator.manual_seed(seed)
        kwargs[StochasticMixerKwargs.generator] = generator

        for mixer in self.mixers.values():
            mixer.preprocess(kwargs)

    def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
        """
        Return expected compute usage (weighted average of all mixers).

        This gives a more accurate estimate than just using one mixer,
        since during training we'll be using all of them according to
        their sampling probabilities.
        """
        usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()]

        # Weight by sampling probability and return the expected value
        expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))

        return int(expected_usage)

    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        """
        Merge loss definitions from all mixers with namespacing.

        Each mixer's losses are namespaced with the mixer name to avoid conflicts.
        This ensures we allocate space for any auxiliary losses that any
        of the mixers might need, even if multiple mixers have losses with the same name.
        """
        all_losses = []
        for mixer_name, mixer in self.mixers.items():
            mixer_losses = mixer.get_loss_definitions(count=count)
            # Namespace each loss with the mixer name to avoid conflicts
            for loss_def in mixer_losses:
                namespaced_loss = LossDef(
                    name=f"{mixer_name}/{loss_def.name}",
                    formatted_name=f"{mixer_name}/{loss_def.formatted_name}",
                    count=loss_def.count,
                    dtype=loss_def.dtype,
                )
                all_losses.append(namespaced_loss)

        return all_losses
```
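The per-iteration seeding in `preprocess` is what makes the random choice reproducible and identical across ranks (see commit 2fe9596): the seed depends only on the global seed, `seed_shift`, and the iteration, never on the rank. A small standalone sketch of that property, using made-up seeds and probabilities rather than the real config:

```python
import torch

def sample_sequence(seed: int, probs: torch.Tensor, steps: int) -> list[int]:
    # Mimics preprocess(): a fresh CPU generator per iteration, seeded from
    # values that are the same on every rank.
    draws = []
    for iteration in range(steps):
        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed + iteration)
        draws.append(torch.multinomial(probs, num_samples=1, generator=generator).item())
    return draws

probs = torch.tensor([0.5, 0.3, 0.2])  # hypothetical sampling weights
# Two "ranks" sharing the same global seed make identical choices each step.
assert sample_sequence(1234, probs, steps=8) == sample_sequence(1234, probs, steps=8)
```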
Review discussion:

"We could just enforce this and make `main_mixer_name` a cached property instead?"

"I like it the way it is."
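A rough sketch of what the reviewer's suggestion might look like; the config class below and the rule for choosing the default mixer are assumptions for illustration, not code from the PR:

```python
import functools


class StochasticMixerConfigSketch:
    """Hypothetical config fragment, not the PR's actual StochasticMixerConfig."""

    def __init__(self, mixers: dict, main_mixer_name: str | None = None):
        self.mixers = mixers
        self._main_mixer_name = main_mixer_name

    @functools.cached_property
    def main_mixer_name(self) -> str:
        # Computed (and cached) on first access instead of being set during
        # validation; falls back to the first configured mixer.
        return self._main_mixer_name or next(iter(self.mixers))
```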