Add create new mlp variation with two gates #795
klei22 wants to merge 3 commits into ReaLLMASIC:master from
Conversation
Pull request overview
This PR adds a new MLP variant (swiglu_2gate_pre_act) to the model-variation system and wires it into CLI/config + exploration tooling to enable parameter-matched comparisons against existing MLP/SwiGLU variants.
Changes:
- Introduces `SwiGLUTwoGatesPreAct` and registers it as `swiglu_2gate_pre_act` in the MLP factory.
- Extends CLI argument choices to allow selecting the new MLP variant and makes the default `--device` more explicit (`cuda:0`).
- Adds a new exploration YAML to run parameter-matched sweeps comparing SwiGLU, the new 2-gate variant, and plain MLP activations on minipile.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| variations/mlp_variations.py | Implements and registers the new 2-gate pre-activation SwiGLU MLP module. |
| train_args.py | Exposes the new MLP variant in CLI choices and updates default device string. |
| explorations/mlp_equal_params_vs_swiglu_minipile.yaml | Adds an experiment grid to compare MLP activation variants under (approx.) parameter-matched settings. |
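The diff itself does not show the exploration YAML's schema. Purely as an illustration of the kind of parameter-matched sweep described here, a grid of named groups might look like the following; every key name and value below is hypothetical, not taken from the repo:

```yaml
# Hypothetical sketch - key names are NOT taken from the repository.
dataset: minipile
groups:
  - name: swiglu_baseline
    mlp_variant: [swiglu]
    mlp_size: [2048]
  - name: two_gate_pre_act
    mlp_variant: [swiglu_2gate_pre_act]
    mlp_size: [1536]   # 4 weight matrices, matched to 3 x 2048
  - name: plain_mlp_matched
    mlp_variant: [mlp]
    mlp_activation_variant: [gelu, relu, silu]
    mlp_size: [3072]   # 2 weight matrices, matched to 3 x 2048
```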
```python
if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
    num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
    quant_method = self.quantization_mlp_dict["activations_quant_method"]
    x_main = fake_quantize_act(self, "mlp_act_activation_input", x_main, num_bits, quant_method, iter_num)

gate1 = self._up_project(x, self.c_fc_gate1)
gate2 = self._up_project(x, self.c_fc_gate2)

x_out = (x_main * gate1) * gate2
```
In SwiGLUTwoGatesPreAct.forward, quantize_mlp_act_activation_input is applied to x_main, but the activation function is actually applied to x_out (after multiplying by both gates). This means the “activation input” quantization is quantizing the wrong tensor for this variant and will make quantization runs inconsistent with other MLPs (e.g., DualPathSwiglu quantizes the true pre-activation tensor). Apply this quantization step to x_out right before calling self.activation_variant(...) (and consider whether the gates themselves should be quantized under the same flag).
Suggested change:

```diff
-if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
-    num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
-    quant_method = self.quantization_mlp_dict["activations_quant_method"]
-    x_main = fake_quantize_act(self, "mlp_act_activation_input", x_main, num_bits, quant_method, iter_num)
-gate1 = self._up_project(x, self.c_fc_gate1)
-gate2 = self._up_project(x, self.c_fc_gate2)
-x_out = (x_main * gate1) * gate2
+gate1 = self._up_project(x, self.c_fc_gate1)
+gate2 = self._up_project(x, self.c_fc_gate2)
+x_out = (x_main * gate1) * gate2
+if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
+    num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
+    quant_method = self.quantization_mlp_dict["activations_quant_method"]
+    x_out = fake_quantize_act(self, "mlp_act_activation_input", x_out, num_bits, quant_method, iter_num)
```
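The corrected ordering can be checked with a minimal, self-contained sketch. This is an illustration only, assuming PyTorch: `TwoGatePreActSketch` and `fake_quantize` are hypothetical stand-ins for the repo's `SwiGLUTwoGatesPreAct` and `fake_quantize_act`, with the quantization config reduced to a single flag:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quantize(x, num_bits=8):
    # Stand-in for the repo's fake_quantize_act: symmetric per-tensor
    # fake quantization (quantize, round, clamp, then dequantize).
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

class TwoGatePreActSketch(nn.Module):
    def __init__(self, n_embd, n_hidden, quantize_act_input=False):
        super().__init__()
        self.c_fc_main = nn.Linear(n_embd, n_hidden)
        self.c_fc_gate1 = nn.Linear(n_embd, n_hidden)
        self.c_fc_gate2 = nn.Linear(n_embd, n_hidden)
        self.c_proj = nn.Linear(n_hidden, n_embd)
        self.quantize_act_input = quantize_act_input

    def forward(self, x):
        x_main = self.c_fc_main(x)
        gate1 = self.c_fc_gate1(x)
        gate2 = self.c_fc_gate2(x)
        x_out = (x_main * gate1) * gate2
        # Quantize the true activation input, i.e. the tensor the
        # non-linearity actually sees, as the review comment suggests.
        if self.quantize_act_input:
            x_out = fake_quantize(x_out)
        return self.c_proj(F.silu(x_out))

mlp = TwoGatePreActSketch(8, 16, quantize_act_input=True)
y = mlp(torch.randn(2, 4, 8))
print(y.shape)  # torch.Size([2, 4, 8])
```

The key point is simply that the fake-quantization call moves after the gating products, so the quantized tensor is the one fed to the activation function.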
This pull request introduces a new MLP variant called `swiglu_2gate_pre_act`, expands the configuration and experiment setup to compare this and other MLP activation variants under parameter-matched conditions, and makes minor improvements to argument parsing and configuration handling. The main focus is on enabling and evaluating the new two-gate SwiGLU pre-activation architecture alongside other variants.

Key changes:

New MLP variant and integration
- Adds the `SwiGLUTwoGatesPreAct` class in `mlp_variations.py`, which introduces a SwiGLU variant with two gates applied before the non-linearity, including all relevant quantization, normalization, and offset logic. This is now available as `swiglu_2gate_pre_act` in the activation dictionary and MLP instantiation logic.
- Adds `"swiglu_2gate_pre_act"` to the list of supported MLP variants in the argument parser in `train_args.py`, so it can be selected via CLI/config.

Experimental configuration and comparison
- Adds `mlp_equal_params_vs_swiglu_minipile.yaml`, which sets up a comprehensive comparison of regular SwiGLU, dual-path, and parameter-matched plain MLP variants (with various activations) on the minipile dataset. This includes a rationale for parameter matching and defines multiple named groups for systematic exploration.

Configuration and usability improvements
- Changes the default device in `train_args.py` from `'cuda'` to `'cuda:0'` for more explicit device selection.
- Adds `l2_norm_print_dims` in the MLP config initialization for potential debugging or logging.
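The parameter-matching idea behind the exploration config follows from counting weight matrices: a standard SwiGLU block has three `d × h` matrices (up, gate, down), a plain MLP has two, and the two-gate variant has four. A short sketch of that arithmetic, with illustrative (hypothetical) sizes:

```python
def mlp_params(d, h, n_mats):
    # Weight-only count for an MLP block built from n_mats
    # d-by-h (or h-by-d) matrices; biases ignored for simplicity.
    return n_mats * d * h

d = 768          # hypothetical embedding width
h_swiglu = 2048  # hypothetical SwiGLU hidden size

target = mlp_params(d, h_swiglu, 3)  # SwiGLU: up, gate, down
h_plain = target // (d * 2)          # plain MLP: up, down
h_2gate = target // (d * 4)          # 2-gate variant: up, gate1, gate2, down

print(target, h_plain, h_2gate)  # 4718592 3072 1536
```

So to match a SwiGLU with hidden size 2048, a plain MLP needs a hidden size about 1.5x larger and the two-gate variant one about 0.75x as large, which is the kind of adjustment the YAML's "parameter-matched" groups would encode.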