This repository contains a comprehensive from-scratch implementation of decoder-only transformer models, featuring advanced techniques like Grouped Query Attention (GQA), Rotary Positional Embeddings (RoPE), and RMSNorm. The implementation demonstrates end-to-end training on conversational datasets like LMSYS-Chat-1M with modern optimization techniques including mixed precision training and WandB monitoring.
Designed for researchers, developers, and practitioners interested in understanding and implementing transformer architectures for conversational AI applications.
- Modular Architecture - Clean, well-structured transformer implementation
- Advanced Attention - Grouped Query Attention (GQA) for efficient inference
- Rotary Embeddings - RoPE for better positional understanding
- RMSNorm - Stable normalization technique
- Efficient Training - Mixed precision training with gradient clipping
- Monitoring - WandB integration for experiment tracking
- Flexible CLI - Command-line interface for easy configuration
- Dataset Support - Compatible with HuggingFace datasets
- Python 3.8+
- PyTorch 2.0+ (with CUDA support for GPU training)
- Hugging Face Transformers and Datasets libraries
- WandB (optional, for experiment tracking)
1. Clone the repository:

   ```bash
   git clone https://github.com/dino65-dev/Transformers.git
   cd Transformers
   ```

2. Create a virtual environment:

   ```bash
   python -m venv env
   source env/bin/activate  # On Windows: env\Scripts\activate
   ```

3. Install dependencies:

   ```bash
   pip install torch transformers datasets wandb tokenizers
   ```
Train models with the comprehensive CLI interface:
```bash
python train/train.py \
    --d_model 768 \
    --num_layers 10 \
    --num_heads 12 \
    --kv_heads 4 \
    --d_ff 2048 \
    --dropout 0.1 \
    --seq_len 2048 \
    --epochs 5 \
    --batch_size 6 \
    --lr 3e-4 \
    --checkpoint_dir "./checkpoints" \
    --dataset_name "lmsys/lmsys-chat-1m" \
    --dataset_subset 10000
```

| Parameter | Description | Default |
|---|---|---|
| `--d_model` | Model dimension | 768 |
| `--num_layers` | Transformer layers | 10 |
| `--num_heads` | Attention heads | 12 |
| `--kv_heads` | Key-value heads (GQA) | 4 |
| `--d_ff` | Feed-forward dimension | 2048 |
| `--dropout` | Dropout rate | 0.1 |
| `--seq_len` | Sequence length | 2048 |
| `--epochs` | Training epochs | 5 |
| `--batch_size` | Batch size | 6 |
| `--lr` | Learning rate | 3e-4 |
```bash
# Using different datasets
python train/train.py --dataset_name "EleutherAI/pile" --dataset_config "all"
python train/train.py --dataset_name "tatsu-lab/alpaca" --dataset_split "train"

# Enable WandB logging
python train/train.py --project_name "my-transformer" --run_name "experiment-1"

# Disable WandB
python train/train.py --no_wandb
```

Generate text using trained models:
```bash
python test/generate.py \
    --checkpoint ./checkpoints/best_model.pt \
    --prompt "<user> Hello! How are you?" \
    --max_new_tokens 100 \
    --device cuda
```

For models with custom hyperparameters:
```bash
python test/generate.py \
    --checkpoint ./checkpoints/best_model.pt \
    --prompt "Once upon a time" \
    --d_model 768 --num_layers 10 --num_heads 12 --kv_heads 4
```
```
.
├── transformer/                  # Core model implementation
│   ├── build_transformer.py      # Model builder and configuration
│   ├── transformer_.py           # Main transformer model
│   ├── decoder.py                # Decoder wrapper
│   ├── decoder_block.py          # Transformer decoder blocks
│   ├── gqa.py                    # Grouped Query Attention implementation
│   ├── rope.py                   # Rotary Position Embeddings
│   ├── rope_helper.py            # RoPE utility functions
│   ├── rms_norm.py               # RMSNorm implementation
│   ├── ff_block.py               # Feed-forward blocks
│   ├── residual_connection.py    # Residual connections
│   ├── input_embeddings.py       # Token embeddings
│   ├── positional_encoding.py    # Positional encodings
│   └── projection_layer.py       # Output projection
├── train/                        # Training utilities
│   ├── train.py                  # Main training script
│   ├── dataset_define.py         # Dataset processing
│   ├── tokenizer.py              # Tokenizer setup
│   ├── coustom_tokenizer.py      # Custom tokenizer implementation
│   └── save_checkpoint.py        # Checkpoint management
├── test/                         # Inference and testing
│   └── generate.py               # Text generation script
├── model_train.ipynb             # Training notebook
├── test-ng.ipynb                 # Testing and generation notebook
├── auto_checkpoint_epoch_1_step_48403.pt  # Sample checkpoint
└── Screenshot 2025-08-01 214849.png       # Training loss visualization
```
- Grouped Query Attention (GQA): Efficient attention mechanism reducing memory usage
- Rotary Position Embeddings (RoPE): Superior positional encoding for long sequences
- RMSNorm: Stable and efficient normalization
- Feed-Forward Networks: SwiGLU activation with configurable dimensions
- Residual Connections: Skip connections for stable training
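As a rough illustration of the GQA idea (not the repository's exact `gqa.py`), query heads share a smaller set of key/value heads; each KV head is repeated so it lines up with its group of query heads:

```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v, num_heads=12, kv_heads=4):
    """Grouped Query Attention sketch.
    Shapes: q is (B, num_heads, T, D); k and v are (B, kv_heads, T, D).
    Each KV head serves num_heads // kv_heads query heads."""
    group = num_heads // kv_heads
    # Repeat each KV head to match its group of query heads
    k = k.repeat_interleave(group, dim=1)  # (B, num_heads, T, D)
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return F.softmax(scores, dim=-1) @ v

B, T, D = 2, 8, 64
q = torch.randn(B, 12, T, D)
k = torch.randn(B, 4, T, D)   # only 4 KV heads: smaller KV cache
v = torch.randn(B, 4, T, D)
out = gqa_attention(q, k, v)
print(out.shape)  # torch.Size([2, 12, 8, 64])
```

With `num_heads=12` and `kv_heads=4`, the KV cache is one third the size of standard multi-head attention, which is the main memory win during inference.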
- Model Dimensions: `d_model=768`, `d_ff=2048`
- Architecture: `num_layers=10`, `num_heads=12`, `kv_heads=4`
- Context Length: `seq_len=2048`
- Tokenizer: GPT-2 tokenizer with special tokens `[PAD]`, `<user>`, `<assistant>`
- Optimization: AdamW with mixed precision training
- Hardware: Ola Krutim AI Pod A100 40GB
- Training Time: ~6.5 hours for full run
- Dataset: LMSYS-Chat-1M (subset)
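The AdamW plus mixed-precision combination above follows the standard PyTorch AMP pattern; a minimal sketch (the real loop lives in `train/train.py`, and the tiny `nn.Linear` and synthetic loss here are purely illustrative):

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
# GradScaler is a no-op when disabled, so the same loop runs on CPU
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

for step in range(3):
    x = torch.randn(4, 16, device=device)
    # Forward pass runs in reduced precision inside autocast
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(opt)  # unscale before clipping so the norm is correct
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad(set_to_none=True)
```

Note the `unscale_` call before `clip_grad_norm_`: clipping scaled gradients would use the wrong norm.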
Training loss curve showing convergence from ~10.3 to ~4.0-5.0 range over epochs
Typical Training Progression:
- Epoch 1: Average Loss ~6.0 (starting from ~10.3)
- Later Epochs: Convergence to ~4.0-5.0 range
- Monitoring: Real-time tracking via WandB
- ✅ Automatic checkpointing every 2 hours
- ✅ Best model saving based on validation loss
- ✅ Mixed precision training (FP16)
- ✅ Gradient clipping for stability
- ✅ GPU memory optimization
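The checkpoint/resume pattern behind these features can be sketched with plain `torch.save`/`torch.load` (the repository's actual logic lives in `train/save_checkpoint.py`; the field names and values below are illustrative assumptions):

```python
import torch
from torch import nn

model = nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save everything needed to resume training, not just the weights
ckpt = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
    "epoch": 1,                # illustrative bookkeeping fields
    "best_val_loss": 4.2,
}
torch.save(ckpt, "checkpoint.pt")

# Resume: restore both model and optimizer state
loaded = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(loaded["model_state_dict"])
opt.load_state_dict(loaded["optimizer_state_dict"])
```

Saving the optimizer state matters for AdamW: its per-parameter moment estimates would otherwise restart cold on resume.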
- ✅ Flexible dataset loading from HuggingFace
- ✅ Configurable sampling parameters
- ✅ Custom prompt templates
- ✅ Efficient inference with KV caching
- ✅ Multi-turn conversation support
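The configurable sampling mentioned above follows the usual temperature plus top-k recipe; a generic per-step sketch (not the exact code in `test/generate.py`):

```python
import torch

def sample_next_token(logits, temperature=0.8, top_k=50):
    """Pick the next token id from a (vocab_size,) logits vector."""
    logits = logits / max(temperature, 1e-6)  # sharpen or flatten the distribution
    if top_k is not None:
        top = torch.topk(logits, min(top_k, logits.size(-1)))
        # Mask everything outside the top-k to -inf before softmax
        masked = torch.full_like(logits, float("-inf"))
        logits = masked.scatter(0, top.indices, top.values)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.tensor([0.1, 2.5, 0.3, 1.2])  # toy vocabulary of 4 tokens
token_id = sample_next_token(logits, temperature=0.8, top_k=2)
# token_id is 1 or 3: only the two highest-scoring tokens survive top_k=2
```

In a full generation loop, this step repeats with the sampled id appended to the context, while the KV cache lets each step reuse attention state from earlier positions.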
Contributions welcome! Areas for enhancement:
- Model Performance: Improve generation quality and coherence
- Evaluation: Add perplexity and other standard metrics
- Scalability: Support for larger model configurations
- Documentation: Enhanced code documentation and examples
- Optimization: Further training and inference speedups
This project is licensed under the MIT License - see the LICENSE file for details.
- Inspired by modern transformer architectures and best practices
- Built with PyTorch and Hugging Face ecosystem
- Training infrastructure powered by Ola Krutim AI Pod
Ready to train your own transformer?
For questions, feature requests, or collaboration opportunities, please open an issue!
