🧠 Memorax

A unified reinforcement learning framework featuring memory-augmented algorithms and POMDP environment implementations. This repository provides modular components for building, configuring, and running a variety of RL algorithms on classic and memory-intensive environments.


✨ Features

  • 🤖 Memory-RL: JAX implementations of DQN, PPO (Discrete & Continuous), SAC (Discrete & Continuous), PQN, IPPO, R2D2, and their memory-augmented variants with burn-in support for recurrent networks.
  • 📦 Pure JAX Episode Buffer: A fully JAX-native episode buffer enabling efficient storage and sampling of complete episodes for recurrent training, with support for Prioritized Experience Replay.
  • 🔁 Sequence Models: LSTM/GRU (via Flax), sLSTM/mLSTM, FFM/SHM, S5/LRU/Mamba/MinGRU, plus Self-Attention and Linear Attention blocks. GPT-2/GTrXL/xLSTM-style architectures are composed from these primitives (see examples/architectures).
  • 🧬 Networks: MLP, CNN, and ViT encoders with support for RoPE and ALiBi positional embeddings, and Mixture of Experts (MoE) for horizontal scaling.
  • 🎮 Environments: Support for Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, XMiniGrid, and JaxMARL, all reached through a single factory (see the sketch after this list).
  • 📊 Logging & Sweeps: Support for a CLI Dashboard, Weights & Biases, TensorBoardX, and Neptune.
  • 🔧 Easy to Extend: Clear directory structure for adding new networks, algorithms, or environments.
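
All of these suites are exposed through one environment factory with a suite::env_id naming scheme. A minimal sketch, assuming the other suites follow the same prefix convention as the Gymnax call used in the Usage section below:

from memorax.environments import environment

# "gymnax::CartPole-v1" selects the CartPole-v1 task from the Gymnax suite;
# other suites are presumably addressed by swapping the prefix.
env, env_params = environment.make("gymnax::CartPole-v1")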

📥 Installation

Install Memorax (requires Python 3.11+) using pip:

pip install memorax

Or using uv:

uv add memorax

Optionally, add CUDA support with:

pip install memorax[cuda]

Optional: to log with Weights & Biases, authenticate first:

wandb login

🚀 Quick Start

Run a default DQN experiment on CartPole:

uv run examples/dqn_cartpole.py

💻 Usage

import jax
import optax
from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

# Create the environment through the suite::env_id factory.
env, env_params = environment.make("gymnax::CartPole-v1")

cfg = PPOConfig(
    name="PPO-GTrXL",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
)

features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))

# GTrXL-style block: pre-norm self-attention with ALiBi, wrapped in
# segment-level recurrence (a 64-step memory) and a gated residual.
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)

# Actor and critic reuse the same feature extractor and torso definitions
# with different heads; PPO takes separate optimizers for each (identical here).
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))

agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
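
Because train threads the RNG key and agent state explicitly, it can be called repeatedly to continue training from where the previous call stopped. A minimal sketch (the exact contents of transitions depend on the algorithm):

# Resume training in chunks; each call picks up from the returned state.
for _ in range(5):
    key, state, transitions = agent.train(key, state, num_steps=10_000)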

📂 Project Structure

memorax/
├─ examples/          # Small runnable scripts (e.g., DQN CartPole)
├─ memorax/
│  ├─ algorithms/     # DQN, PPO, SAC, PQN, ...
│  ├─ networks/       # MLP, CNN, ViT, RNN, heads, ...
│  ├─ environments/   # Gymnax / PopGym / Brax / ...
│  ├─ buffers/        # Custom flashbax buffers
│  ├─ loggers/        # CLI, WandB, TensorBoardX integrations
│  └─ utils/

🧩 JAX POMDP Ecosystem

Memorax is designed to work alongside a growing suite of JAX-native tools focused on partial observability and memory. These projects provide the foundational architectures and benchmarks for modern memory-augmented RL:

🧠 Architectures & Infrastructure

  • Memax: A library for efficient sequence and recurrent modeling in JAX. It provides unified interfaces for fast recurrent state resets and associative scans, serving as a powerful primitive for building memory architectures.
  • Flashbax: The library powering Memorax's buffer system. It provides high-performance, JAX-native experience replay buffers optimized for sequence storage and prioritized sampling (see the sketch after this list).
  • Gymnax: The standard for JAX-native RL environments. Memorax provides seamless wrappers to run recurrent agents on these vectorized tasks.
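
Flashbax buffers are pure functions over an explicit buffer state, which is what makes them straightforward to jit and vmap. A minimal sketch of Flashbax used directly, independent of Memorax's own episode buffer (capacity and batch sizes here are illustrative):

import jax
import jax.numpy as jnp
import flashbax as fbx

# A flat transition buffer: stores single timesteps and samples
# (first, second) transition pairs for one-step TD targets.
buffer = fbx.make_flat_buffer(max_length=10_000, min_length=64, sample_batch_size=32)

example = {"obs": jnp.zeros(4), "reward": jnp.float32(0.0)}
state = buffer.init(example)        # allocate storage from an example timestep
state = buffer.add(state, example)  # add is pure: it returns a new state
if buffer.can_sample(state):
    batch = buffer.sample(state, jax.random.key(0))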

🎮 POMDP Benchmarks & Environments

  • PopGym Arcade: A JAX-native suite of "pixel-perfect" POMDP environments. It features Atari-style games specifically designed to test long-term memory with hardware-accelerated rendering.
  • PopJym: A fast, JAX-native implementation of the POPGym benchmark suite, providing a variety of classic POMDP tasks optimized for massive vectorization.
  • Navix: Accelerated MiniGrid-style environments. These are excellent for testing spatial reasoning and navigation in partially observable grid worlds.
  • XLand-MiniGrid: A high-throughput meta-RL environment suite that provides massive task diversity for testing agent generalization in POMDPs.

📄 License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

📚 Citation

If you use Memorax in your work, please cite:

@software{memorax2025github,
  title   = {Memorax: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author  = {Noah Farr},
  year    = {2025},
  url     = {https://github.com/memory-rl/memorax}
}

🙏 Acknowledgments

Special thanks to @huterguier for the valuable discussions and advice on the API design.
