Flexible residual connections for PyTorch with a clean, composable API.
Build complex residual architectures without boilerplate. torchresidual provides
Record and Apply modules that let you create skip connections of any depth,
with automatic shape handling and learnable mixing coefficients.
📖 Quick Start | 📚 Full Documentation | 💡 Examples | ❓ FAQ
Standard PyTorch:

```python
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        residual = x
        x = self.linear(x)
        x = F.relu(x)
        x = self.norm(x)
        return x + residual  # Manual residual
```

With torchresidual:

```python
block = ResidualSequential(
    Record(name="input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    Apply(record_name="input"),  # Automatic residual
)
```

Benefits:
- No custom `forward()` methods
- Multiple skip connections with named records
- Automatic projection when dimensions change
- Five residual operations (add, concat, multiply, gated, highway)
- Learnable mixing coefficients
- Works with LSTMs, attention, and any `nn.Module`
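Conceptually, `Record` stashes the incoming tensor in shared state and `Apply` reads it back to form the skip connection. The sketch below imitates this mental model with plain `nn.Sequential` and an explicit dictionary; the `SketchRecord`/`SketchApply` names are invented for illustration, and the real library manages this state internally rather than through a user-visible dict:

```python
import torch
import torch.nn as nn

class SketchRecord(nn.Module):
    """Stash the incoming tensor under a name (conceptual stand-in)."""
    def __init__(self, store, name):
        super().__init__()
        self.store = store
        self.name = name

    def forward(self, x):
        self.store[self.name] = x
        return x

class SketchApply(nn.Module):
    """Add the stashed tensor back onto the current activation."""
    def __init__(self, store, name):
        super().__init__()
        self.store = store
        self.name = name

    def forward(self, x):
        return x + self.store[self.name]

store = {}
block = nn.Sequential(
    SketchRecord(store, "input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    SketchApply(store, "input"),
)
out = block(torch.randn(8, 64))  # same shape in, same shape out
```
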
```shell
pip install torchresidual
```

Requirements: Python ≥3.9, PyTorch ≥1.9
New to torchresidual? See the Quick Start Guide for a 5-minute tutorial.
```python
import torch
import torch.nn as nn
from torchresidual import ResidualSequential, Record, Apply

block = ResidualSequential(
    Record(name="input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    Apply(record_name="input", operation="add"),
)

x = torch.randn(8, 64)
out = block(x)  # Shape: [8, 64]
```

Multiple skip connections with named records:

```python
block = ResidualSequential(
    Record(name="input", need_projection=True),
    nn.Linear(64, 32),
    nn.ReLU(),
    Record(name="mid"),
    nn.Linear(32, 64),
    Apply(record_name="input"),  # Long skip with projection
    nn.LayerNorm(64),
    nn.Linear(64, 32),
    Apply(record_name="mid"),    # Short skip
)
```

Learnable mixing coefficients:

```python
from torchresidual import LearnableAlpha

block = ResidualSequential(
    Record(name="r"),
    nn.Linear(64, 64),
    Apply(
        record_name="r",
        operation="gated",
        alpha=LearnableAlpha(0.3, min_value=0.0, max_value=1.0),
    ),
)

# Alpha is learned during training
optimizer = torch.optim.Adam(block.parameters(), lr=1e-3)
```

Automatic projection when dimensions change:

```python
# Input: [batch, 64] → Output: [batch, 128]
block = ResidualSequential(
    Record(name="r", need_projection=True),  # Enables auto-projection
    nn.Linear(64, 128),
    nn.ReLU(),
    Apply(record_name="r"),  # Automatically projects 64→128
)
```

Recurrent networks:

```python
from torchresidual import RecurrentWrapper

block = ResidualSequential(
    Record(name="r"),
    RecurrentWrapper(
        nn.LSTM(32, 32, num_layers=2, batch_first=True),
        return_hidden=False,
    ),
    Apply(record_name="r"),
)

x = torch.randn(4, 10, 32)  # [batch, seq_len, features]
out = block(x)
```

`ResidualSequential` is a drop-in replacement for `nn.Sequential` with residual connection support.
Example:

```python
block = ResidualSequential(
    nn.Linear(64, 64),
    Record(),
    nn.ReLU(),
    Apply(),
)
```

`Record` saves the current tensor for later use in a residual connection.
Args:
- `need_projection` (bool): if `True`, `Apply` will create a linear projection when shapes don't match
- `name` (str, optional): label for this record point; auto-assigned if `None`

Example:

```python
Record(name="input", need_projection=True)
```

`Apply` applies a residual connection using a previously recorded tensor.
Args:
- `operation` (str): one of `"add"`, `"concat"`, `"multiply"`, `"gated"`, `"highway"`
- `record_name` (str, optional): which `Record` to use; if `None`, uses the most recent
- `alpha` (float or LearnableAlpha): scaling factor for the residual branch

Operations:

| Operation | Formula | Use case |
|---|---|---|
| `add` | `x + α·r` | Standard ResNet-style |
| `concat` | `cat([x, r], dim=-1)` | DenseNet-style |
| `multiply` | `x·(1 + α·r)` | Multiplicative skip |
| `gated` | `(1-α)·x + α·r` | Learnable interpolation |
| `highway` | `T·x + C·r` | Highway Networks |
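The first four formulas in the table can be written as plain tensor operations; this is an illustrative sketch rather than the library's code, and `highway` is omitted here because its transform gate `T` and carry gate `C` are learned modules rather than scalars:

```python
import torch

def add(x, r, alpha=1.0):
    # Standard ResNet-style: x + α·r
    return x + alpha * r

def concat(x, r):
    # DenseNet-style: doubles the last dimension
    return torch.cat([x, r], dim=-1)

def multiply(x, r, alpha=1.0):
    # Multiplicative skip: x·(1 + α·r)
    return x * (1 + alpha * r)

def gated(x, r, alpha=0.5):
    # Interpolation between branch output and residual: (1-α)·x + α·r
    return (1 - alpha) * x + alpha * r
```
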
Example:

```python
Apply(operation="gated", record_name="input", alpha=0.5)
```

`LearnableAlpha` is a learnable scalar parameter constrained to `[min_value, max_value]`.
Args:
- `initial_value` (float): starting value
- `min_value` (float): lower bound (inclusive)
- `max_value` (float): upper bound (inclusive)
- `use_log_space` (bool, optional): force log or linear parameterization; auto-detected if `None`

Example:

```python
alpha = LearnableAlpha(0.5, min_value=0.0, max_value=1.0)
x = x + alpha() * residual  # alpha() returns the constrained value
```

`RecurrentWrapper(module, return_hidden=False)`
Wraps LSTM/GRU modules for seamless integration with `ResidualSequential`.

Args:
- `module` (nn.Module): the recurrent module (e.g., `nn.LSTM`)
- `return_hidden` (bool): if `True`, returns an `(output, hidden)` tuple

Example:

```python
RecurrentWrapper(nn.LSTM(64, 64, batch_first=True), return_hidden=False)
```

Multi-head attention with residual and layer norm:
```python
block = ResidualSequential(
    Record(name="input"),
    nn.MultiheadAttention(embed_dim=256, num_heads=8),
    Apply(record_name="input"),
    nn.LayerNorm(256),
    Record(name="attn_out"),
    nn.Linear(256, 1024),
    nn.ReLU(),
    nn.Linear(1024, 256),
    Apply(record_name="attn_out"),
    nn.LayerNorm(256),
)
```

Nested residual blocks:

```python
inner_block = ResidualSequential(
    Record(),
    nn.Linear(64, 64),
    nn.ReLU(),
    Apply(),
)

outer_block = ResidualSequential(
    Record(),
    inner_block,
    nn.Linear(64, 64),
    Apply(),
)
```

Named modules with an `OrderedDict`:

```python
from collections import OrderedDict

encoder = ResidualSequential(OrderedDict([
    ('record_input', Record(need_projection=True, name="input")),
    ('conv1', nn.Conv1d(64, 128, kernel_size=3, padding=1)),
    ('relu1', nn.ReLU()),
    ('record_mid', Record(name="mid")),
    ('conv2', nn.Conv1d(128, 128, kernel_size=3, padding=1)),
    ('relu2', nn.ReLU()),
    ('apply_long', Apply(record_name="input")),
    ('norm', nn.BatchNorm1d(128)),
    ('conv3', nn.Conv1d(128, 64, kernel_size=1)),
    ('apply_short', Apply(record_name="mid", operation="concat")),
]))
```

| Environment | Status | Notes |
|---|---|---|
| Single GPU training | ✅ | Full support |
| CPU training | ✅ | Full support |
| `nn.DataParallel` | ✅ | Thread-safe via `threading.local()` |
| `DistributedDataParallel` | ✅ | Process-safe; recommended for multi-GPU |
| Multi-threaded inference | ✅ | Safe for Flask/FastAPI servers |
| Jupyter notebooks | ✅ | Full support |
| `torch.jit.script` | ❌ | Planned for v1.1 |
| ONNX export | ❌ | Planned for v1.1 |
torchresidual uses `threading.local()` for context management, making it safe for:
- `nn.DataParallel` (multiple GPU threads)
- Multi-threaded inference servers
- Concurrent requests in production

See docs/DESIGN.md for implementation details.
Traditional approaches store a parent reference in `Apply`, creating circular references:

```
ResidualSequential → Apply → ResidualSequential  # Breaks pickle/deepcopy
```

torchresidual uses `threading.local()` to avoid this:
- ✅ No circular references
- ✅ Works with `pickle`, `torch.save`, and `deepcopy`
- ✅ Thread-safe for `nn.DataParallel`
- ✅ Clean module hierarchy
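A minimal, stdlib-only illustration of the pattern (not the library's code): records live in per-thread storage, so nothing holds a reference back to a parent container and concurrent workers cannot see each other's tensors. The `record`/`fetch` helper names here are invented for the sketch:

```python
import threading

_ctx = threading.local()  # each thread sees its own attribute namespace

def record(name, value):
    # Lazily create a per-thread record store; no module ever holds
    # a reference back to its parent container.
    if not hasattr(_ctx, "records"):
        _ctx.records = {}
    _ctx.records[name] = value

def fetch(name):
    return _ctx.records[name]

# Concurrent forward passes cannot collide: each worker thread
# records and fetches in its own isolated store.
results = {}

def worker(tag):
    record("input", tag)
    results[tag] = fetch("input")

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```
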
LearnableAlpha uses tanh (not sigmoid) for bounded parameters:
- Better gradient flow near boundaries
- Symmetric around midpoint
- Stable training dynamics
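A sketch of the idea, assuming the standard tanh mapping from an unconstrained parameter onto a bounded interval (the library's internals may differ in detail):

```python
import math

def bounded(theta, lo, hi):
    # tanh maps all of R into (-1, 1); rescale that to (lo, hi).
    # The result is symmetric around the midpoint, and gradients
    # stay nonzero as theta approaches the boundaries.
    return lo + (hi - lo) * (math.tanh(theta) + 1) / 2
```
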
For ranges spanning orders of magnitude (e.g., 1e-4 to 1e-1), linear space explores the lower end poorly. Log space provides uniform coverage:

```python
alpha = LearnableAlpha(0.01, min_value=1e-4, max_value=1.0)
# Automatically uses log space (ratio > 100)
```

See the examples/ directory:
- `basic_usage.py`: core concepts
- `advanced_usage.py`: advanced concepts
- `lstm_residual.py`: recurrent networks
Contributions welcome! Please:
- Fork the repository
- Create a feature branch
- Add tests for new functionality
- Ensure `pytest` and `mypy` pass
- Submit a pull request
Development setup:

```shell
git clone https://github.com/v-garzon/torchresidual.git
cd torchresidual
pip install -e ".[dev]"
pytest tests/
mypy torchresidual/
```

If you use torchresidual in your research, please cite:
```bibtex
@software{torchresidual2026,
  author = {Garzón, Víctor},
  title = {torchresidual: Flexible residual connections for PyTorch},
  year = {2026},
  url = {https://github.com/v-garzon/torchresidual}
}
```

MIT License - see LICENSE for details.
See CHANGELOG.md for version history.