From c7ad36d3433f978bb06ad30da212134d769c9b86 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 21 Jan 2026 08:54:31 -0800 Subject: [PATCH] Revert "add mixed precision training for lm workload" --- .../finewebedu_lm/finewebedu_lm_jax/models.py | 91 +++++++-------- .../finewebedu_lm_jax/workload.py | 49 +------- .../finewebedu_lm_pytorch/models.py | 110 ++++++------------ .../finewebedu_lm_pytorch/workload.py | 95 ++++++--------- algoperf/workloads/finewebedu_lm/workload.py | 20 +--- pyproject.toml | 3 +- 6 files changed, 122 insertions(+), 246 deletions(-) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py index 3419fe6fb..d08e9b7bf 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -8,7 +8,6 @@ import jax import jax.numpy as jnp -import jmp from flax import linen as nn @@ -27,24 +26,18 @@ class ModelConfig: use_residual_scaling: bool = True tie_embeddings: bool = True # Whether to tie input and output embed qknorm_epsilon: float = 1e-6 + + dtype: jnp.dtype = jnp.float32 attention_init: nn.initializers.Initializer = nn.initializers.normal( stddev=0.02 ) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) - param_dtype: jnp.dtype = jnp.float32 - compute_dtype: jnp.dtype = jnp.bfloat16 - output_dtype: jnp.dtype = jnp.bfloat16 def __post_init__(self): self.residual_init = nn.initializers.normal( stddev=0.02 / jnp.sqrt(2 * self.num_layers) ) - self.mp_policy = jmp.Policy( - compute_dtype=self.compute_dtype, - param_dtype=self.param_dtype, - output_dtype=self.output_dtype, - ) class Mlp(nn.Module): @@ -56,11 +49,7 @@ class Mlp(nn.Module): def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg linear = partial( - nn.Dense, - kernel_init=cfg.linear_init, - use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype ) # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D @@ -76,8 +65,7 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLxD = nn.Dense( cfg.model_dim, use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, @@ -108,7 +96,7 @@ def apply_rope(q, k, freqs_cis): def rotate_tensor(x): # Split into real and imaginary parts - x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32) + x_r2 = x.reshape(*x.shape[:-1], -1, 2) L = x.shape[1] freqs = freqs_cis[:, :L, :, :, :] @@ -121,7 +109,7 @@ def rotate_tensor(x): axis=-1, ) - return rotated_x_r2.reshape(*x.shape).astype(x.dtype) + return rotated_x_r2.reshape(*x.shape) # Apply rotation to Q and K separately rotated_q = rotate_tensor(q) @@ -153,8 +141,7 @@ def setup(self): features=(cfg.num_heads, self.Dh), kernel_init=cfg.attention_init, use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, ) self.multilinear_query = self.multilinear(name='query') self.multilinear_key = self.multilinear(name='key') @@ -163,9 +150,7 @@ def setup(self): seq_len = cfg.seq_len attn_scale0 = jnp.log2(seq_len**2 - seq_len) self.attn_scale = self.param( - 'attn_scale', - nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype), - (), + 'attn_scale', 
nn.initializers.constant(attn_scale0), () ) self.output_projection = nn.DenseGeneral( features=cfg.model_dim, @@ -175,8 +160,7 @@ def setup(self): if cfg.use_residual_scaling else cfg.linear_init, use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, ) def __call__(self, x_BxLxD: jax.Array): @@ -193,17 +177,32 @@ def __call__(self, x_BxLxD: jax.Array): # Apply QK normalization q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps - q_BxLxHxDh *= self.attn_scale - out_BxLxHxDh = jax.nn.dot_product_attention( - query=q_BxLxHxDh, - key=k_BxLxHxDh, - value=v_BxLxHxDh, - is_causal=True, - scale=1.0, - implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None, - ) + + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection out_BxLxD = self.output_projection(out_BxLxD) + return out_BxLxD @@ -217,16 +216,16 @@ def __call__(self, in_BxLxD: jax.Array): cfg = self.docfg # x = x + attn( attn_norm(x) ) - x_BxLxD = nn.RMSNorm( - param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon - )(in_BxLxD) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) x_BxLxD = CausalAttn(cfg)(x_BxLxD) x_BxLxD += in_BxLxD # x = x + mlp( mlp_norm(x) ) - z_BxLxD = nn.RMSNorm( - param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon - )(x_BxLxD) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) z_BxLxD = Mlp(cfg)(z_BxLxD) return x_BxLxD + z_BxLxD @@ -243,24 +242,19 @@ def setup(self): num_embeddings=cfg.vocab_size, features=cfg.model_dim, embedding_init=cfg.embed_init, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, ) self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] - self.out_ln = nn.RMSNorm( - param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon - ) + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) # Output projection - tied to input embeddings if configured if cfg.tie_embeddings: - self.output_proj = lambda x: self.embed.attend(x) + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) else: self.output_proj = nn.Dense( cfg.vocab_size, kernel_init=cfg.embed_init, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, name='output_proj', ) @@ -363,7 +357,6 @@ def main(): # Make a prediction (forward pass) print('\nRunning forward pass...') - params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL)) logits = model.apply(params, x_BxL) # Print output shape and sample values diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py index 14366d9ea..ee4cffbbc 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py 
+++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -1,11 +1,9 @@ """LM workload implemented in Jax.""" -from functools import partial from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp -import jmp from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( @@ -15,33 +13,10 @@ from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload -replicated_sharding = jax_sharding_utils.get_replicate_sharding() -batch_sharding = jax_sharding_utils.get_batch_dim_sharding() - -# Dtype mapping from string to JAX dtype -DTYPE_MAP = { - 'float32': jnp.float32, - 'float16': jnp.float16, - 'bfloat16': jnp.bfloat16, -} - class LmWorkload(BaseLmWorkload): """LM JAX workload.""" - # Convert dtype strings from base class to JAX dtypes - @property - def _compute_dtype(self) -> Any: - return DTYPE_MAP[self._compute_dtype_str] - - @property - def _param_dtype(self) -> Any: - return DTYPE_MAP[self._param_dtype_str] - - @property - def _output_dtype(self) -> Any: - return DTYPE_MAP[self._output_dtype_str] - def _build_input_queue( self, data_rng: jax.random.PRNGKey, @@ -78,14 +53,8 @@ def init_model_fn( num_layers=self._n_layers, # num layers vocab_size=self._vocab_size, expanded_model_dim=self._mlp_dim, # feedforward dim - rmsnorm_epsilon=self._rmsnorm_epsilon, - qknorm_epsilon=self._qknorm_epsilon, - tie_embeddings=self._tie_embeddings, - param_dtype=self._param_dtype, - compute_dtype=self._compute_dtype, - output_dtype=self._output_dtype, + dtype=jnp.float32, ) - self._mp_policy: jmp.Policy = cfg.mp_policy self._model = TransformerDo(cfg) input_shape = (1, self._seq_len) # For token IDs @@ -97,7 +66,8 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) params = jax_sharding_utils.replicate(params) - return params, None + model_state = None + return params, model_state def model_fn( self, @@ -111,12 +81,10 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] - params, inputs = self._mp_policy.cast_to_compute((params, inputs)) # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded inputs = jnp.argmax(inputs, axis=-1) logits = self._model.apply({'params': params}, inputs) - logits = self._mp_policy.cast_to_output(logits) return logits, None def loss_fn( @@ -171,17 +139,6 @@ def loss_fn( 'per_example': per_example_losses, } - @partial( - jax.jit, - static_argnums=(0,), - in_shardings=( - replicated_sharding, - batch_sharding, - replicated_sharding, - replicated_sharding, - ), - out_shardings=(replicated_sharding), - ) def _eval_batch( self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py index 4c60198cc..edee8318c 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -26,24 +26,14 @@ class ModelConfig: qknorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True - compute_dtype: torch.dtype = torch.bfloat16 - param_dtype: torch.dtype = torch.float32 class MLP(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: 
int = 256, - dtype: torch.dtype = torch.float32, - ): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): super().__init__() - hidden_dim = int( - multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - ) - self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False, dtype=dtype) - self.fc2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) self.glu = nn.GLU(dim=2) nn.init.normal_(self.fc1.weight, std=0.02) nn.init.normal_(self.fc2.weight, std=0.02) @@ -98,12 +88,8 @@ def __init__(self, cfg: ModelConfig): self.n_heads = cfg.num_heads self.head_dim = cfg.model_dim // cfg.num_heads - self.w_qkv = nn.Linear( - cfg.model_dim, 3 * cfg.model_dim, bias=False, dtype=cfg.param_dtype - ) - self.w_out = nn.Linear( - cfg.model_dim, cfg.model_dim, bias=False, dtype=cfg.param_dtype - ) + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: @@ -113,9 +99,7 @@ def __init__(self, cfg: ModelConfig): self.eps = cfg.qknorm_epsilon # e.g., 1e-6 seq_len = cfg.seq_len attn_scale0 = math.log2(seq_len**2 - seq_len) - self.attn_scale = nn.Parameter( - torch.tensor(attn_scale0, dtype=cfg.param_dtype) - ) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -158,18 +142,13 @@ class Block(nn.Module): def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm( - cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype - ) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.mlp = MLP( dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of, - dtype=cfg.param_dtype, - ) - self.mlp_norm = nn.RMSNorm( - cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype ) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id def forward(self, x, freqs_cis): @@ -187,18 +166,12 @@ def __init__(self, cfg: ModelConfig): head_dim = cfg.model_dim // cfg.num_heads assert cfg.model_dim % cfg.num_heads == 0 - self.embed_tokens = nn.Embedding( - cfg.vocab_size, cfg.model_dim, dtype=cfg.param_dtype - ) + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) self.layers = nn.ModuleList( [Block(idx, cfg) for idx in range(cfg.num_layers)] ) - self.out_norm = nn.RMSNorm( - cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype - ) - self.lm_head = nn.Linear( - cfg.model_dim, cfg.vocab_size, bias=False, dtype=cfg.param_dtype - ) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) # Initialize freqs_cis on CPU first (more memory efficient) self.register_buffer( @@ -242,7 +215,6 @@ def forward(self, x, targets=None): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) - if targets is not None: loss = F.cross_entropy( out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 @@ -260,43 +232,40 @@ def predict(self, x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ - # Determine device type for autocast - 
device_type = 'cuda' if x.is_cuda else 'cpu' - with torch.autocast(device_type=device_type, dtype=self.cfg.compute_dtype): - # Store original input - original_input = x.clone() - generated_input = x.clone() + # Store original input + original_input = x.clone() + generated_input = x.clone() - # Generate k tokens autoregressively - for i in range(k): - # Get logits for the entire sequence - logits = self(generated_input) + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] - # Zero out the last token ID to prevent repetition - # This is a common issue - the model gets stuck repeating the last token - last_token_id = generated_input[:, -1] - next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - # Get the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) - # Append the predicted token to the sequence - next_token = next_token.unsqueeze(1) # Add sequence dimension - generated_input = torch.cat([generated_input, next_token], dim=1) + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) - # For debugging, print predictions for the first item in the batch - print('\nPyTorch detailed prediction (first item in batch):') - predicted_sequence = generated_input[0, -k:].tolist() - print(f' Predicted token IDs: {predicted_sequence}') - for i, token_id in enumerate(predicted_sequence): - print(f' Step {i + 1}: Predicted token {token_id}') + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') - # Return all tokens, not just the last k - return original_input, generated_input[:, -k:] + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -349,8 +318,6 @@ def main(): # Instantiate the model model = Transformer(config) print(f'Model has {model.count_params():,} parameters.') - for n, p in model.named_parameters(): - print(f'{n}.dtype == {p.dtype}') # Create some random input data batch_size = 2 @@ -363,7 +330,6 @@ def main(): # Run a forward pass print(f'Running forward pass with input shape: {input_ids.shape}') logits = model(input_ids) - print(f'Output logits dtype: {logits.dtype}') print(f'Output logits shape: {logits.shape}') # Run prediction diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py index ed922f9c2..a25ca334a 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py +++ 
b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -19,25 +19,10 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -# Dtype mapping from string to PyTorch dtype -DTYPE_MAP = { - 'float32': torch.float32, - 'float16': torch.float16, - 'bfloat16': torch.bfloat16, -} - class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" - @property - def _compute_dtype(self) -> torch.dtype: - return DTYPE_MAP[self._compute_dtype_str] - - @property - def _param_dtype(self) -> torch.dtype: - return DTYPE_MAP[self._param_dtype_str] - def init_model_fn( self, rng: spec.RandomState, @@ -55,14 +40,11 @@ def init_model_fn( vocab_size=self._vocab_size, seq_len=self._seq_len, model_dim=self._emb_dim, # Model dimension - expanded_model_dim=self._mlp_dim, # MLP expanded dim - num_layers=self._n_layers, - num_heads=self._n_heads, - rmsnorm_epsilon=self._rmsnorm_epsilon, - qknorm_epsilon=self._qknorm_epsilon, - tie_embeddings=self._tie_embeddings, - compute_dtype=self._compute_dtype, - param_dtype=self._param_dtype, + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, + tie_embeddings=True, ) self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) @@ -99,18 +81,13 @@ def model_fn( spec.ForwardPassMode.EVAL: torch.no_grad, spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } - - # Determine device type for autocast - device_type = 'cuda' if DEVICE.type == 'cuda' else 'cpu' - with contexts[mode](): - with torch.autocast(device_type=device_type, dtype=self._compute_dtype): - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded - inputs = inputs.argmax(dim=-1) + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) return logits, None @@ -144,7 +121,7 @@ def _build_input_queue( batch['targets'], device=DEVICE, dtype=torch.int64 ), 'weights': torch.tensor( - batch['weights'], device=DEVICE, dtype=self._param_dtype + batch['weights'], device=DEVICE, dtype=torch.float32 ) if batch['weights'] is not None else None, @@ -180,35 +157,29 @@ def loss_fn( - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
""" - # Determine device type for autocast - device_type = 'cuda' if logits_batch.is_cuda else 'cpu' - - with torch.autocast(device_type=device_type, dtype=self._compute_dtype): - vocab_size = logits_batch.size(-1) - - # Compute cross-entropy loss with label smoothing - per_example_losses = torch.nn.functional.cross_entropy( - logits_batch.view(-1, vocab_size), - label_batch.view(-1), - reduction='none', - label_smoothing=label_smoothing, - ) - per_example_losses = per_example_losses.view_as(label_batch) - - # Apply weights if provided - if mask_batch is not None: - per_example_losses = per_example_losses * mask_batch - - # Calculate number of valid examples - n_valid_examples = ( - mask_batch.sum() - if mask_batch is not None - else torch.tensor( - label_batch.numel(), - dtype=self._param_dtype, - device=label_batch.device, - ) + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), dtype=torch.float32, device=label_batch.device ) + ) return { 'summed': per_example_losses.sum(), diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index 3abb9c138..e6e2e9ba5 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -27,16 +27,6 @@ class BaseLmWorkload(spec.Workload): _mlp_dim: int = 4096 warmup_factor: float = 0.1 - # Model configuration - _rmsnorm_epsilon: float = 1e-6 - _qknorm_epsilon: float = 1e-6 - _tie_embeddings: bool = True - - # Dtype configuration (as strings, to be converted by framework-specific subclasses) - _compute_dtype_str: str = 'bfloat16' - _param_dtype_str: str = 'float32' - _output_dtype_str: str = 'bfloat16' # Only used by JAX - def __init__(self) -> None: super().__init__() self._param_shapes = None @@ -95,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 31_967 # 8.9 hours + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 2_571 # approximately 25 evals + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: @@ -174,9 +164,9 @@ def _eval_model_on_split( eval_batch = next(self._eval_iters[split]) metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): - eval_metrics.update( - {metric_name: eval_metrics.get(metric_name, 0.0) + metric_value} - ) + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['ppl'] = np.exp(eval_results['loss']).item() diff --git a/pyproject.toml b/pyproject.toml index e3d86df3d..006e7e5cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.26.0", "datasets==3.6.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -99,7 +99,6 @@ 
jax_core_deps = [
   "chex==0.1.86",
   "ml_dtypes==0.5.1",
   "protobuf==4.25.5",
-  "jmp>=0.0.4"
 ]
 jax_cpu = [
   "jax==0.7.0",
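
Reviewer note (illustrative, not part of the diff): the mixed-precision path removed above followed the jmp policy pattern sketched below. This is a minimal sketch assuming the deleted ModelConfig dtype fields and model_fn casts; "forward" is a hypothetical stand-in for the workload's model_fn, not code from this repository.

import jax.numpy as jnp
import jmp

# Policy built by the deleted ModelConfig.__post_init__: parameters kept in
# float32, activations computed in bfloat16, outputs returned in bfloat16.
mp_policy = jmp.Policy(
  param_dtype=jnp.float32,
  compute_dtype=jnp.bfloat16,
  output_dtype=jnp.bfloat16,
)

def forward(model, params, token_ids):
  # Mirrors the deleted JAX model_fn: cast params and inputs to the compute
  # dtype (jmp casts only floating-point leaves, so integer token IDs pass
  # through unchanged), run the model, then cast logits to the output dtype.
  params, token_ids = mp_policy.cast_to_compute((params, token_ids))
  logits = model.apply({'params': params}, token_ids)
  return mp_policy.cast_to_output(logits)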