rakeshbs/MLXDeepQLearning
MLX Deep Q-Learning

Deep Q-Learning playground built around custom game environments and MLX-based DQN variants.

Current environments:

  • Flappy Bird
  • Breakout

Trained Agents

(Demo recordings of the trained Flappy Bird and Breakout agents.)

Current training setup:

  • DQN
  • DoubleDQN
  • Prioritized Experience Replay
  • ParallelRunner with distributed actor/learner layout
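The Double DQN variant above decouples action selection from action evaluation: the online network chooses the next action, and the target network scores it. A minimal framework-agnostic sketch of the target computation (in NumPy rather than MLX, with hypothetical argument names — the repo's actual learner code may differ):

```python
import numpy as np

def double_dqn_targets(rewards, dones, q_online_next, q_target_next, gamma=0.99):
    """Double DQN bootstrap targets.

    The online network's Q-values pick the next action (argmax),
    and the target network's Q-values evaluate that action, which
    reduces the overestimation bias of vanilla DQN.
    """
    # Action selection with the online network
    next_actions = np.argmax(q_online_next, axis=1)
    # Action evaluation with the target network
    batch = np.arange(len(next_actions))
    next_q = q_target_next[batch, next_actions]
    # Zero out the bootstrap term on terminal transitions
    return rewards + gamma * (1.0 - dones) * next_q
```

The same structure ports directly to MLX by swapping `np` for `mx` arrays.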

Requirements

This project is written for Python 3.10 and uses MLX, so it is intended for Apple Silicon machines.

Install dependencies:

python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install mlx
pip install -r requirements.txt

requirements.txt currently contains the non-MLX runtime dependencies:

  • pygame
  • numpy

For video recording during testing:

  • opencv-python (pip install opencv-python)

Training Experiments

Breakout

Train:

python -m experiments.breakout.cnn_dqn

Test latest checkpoint:

python -m experiments.breakout.cnn_dqn --test

Test best checkpoint (by Avg100, the rolling 100-episode average score):

python -m experiments.breakout.cnn_dqn --test --best

Test best single-episode score checkpoint:

python -m experiments.breakout.cnn_dqn --test --best-score

Test with the full 5-life game (no episode termination on life loss):

python -m experiments.breakout.cnn_dqn --test --best --full-game

Test with epsilon-greedy evaluation (e.g. ε=0.05):

python -m experiments.breakout.cnn_dqn --test --best --epsilon=0.05
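Epsilon-greedy evaluation keeps a small amount of exploration during testing, which avoids deterministic loops in Atari-style environments. A minimal sketch of the selection rule (hypothetical helper name — the experiment's actual code may differ):

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon take a uniformly random action,
    otherwise take the greedy (highest-Q) action."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))
```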

Test without rendering (faster, for benchmarking):

python -m experiments.breakout.cnn_dqn --test --best --no-render

Run a fixed number of test episodes:

python -m experiments.breakout.cnn_dqn --test --best --episodes=100

Flags can be combined freely:

python -m experiments.breakout.cnn_dqn --test --best --full-game --no-render --episodes=100 --epsilon=0.05

Flappy Bird

State-vector DQN:

python -m experiments.flappy.dqn
python -m experiments.flappy.dqn --test
python -m experiments.flappy.dqn --test --best

State-vector Double DQN:

python -m experiments.flappy.double_dqn
python -m experiments.flappy.double_dqn --test
python -m experiments.flappy.double_dqn --test --best

Record the best-scoring episode as a video to runs/:

python -m experiments.flappy.double_dqn --test --best --record

Pixel-based CNN Double DQN:

python -m experiments.flappy.cnn_dqn
python -m experiments.flappy.cnn_dqn --test
python -m experiments.flappy.cnn_dqn --test --best

Manual Play

python play.py --game breakout
python play.py --game flappy

See each environment's README for controls.

Checkpoints

Checkpoints are written under checkpoints/<experiment_name>/ as:

  • latest.npz / latest.json
  • best.npz / best.json
  • best_score.npz / best_score.json

For the parallel runner:

  • latest is always the most recent learner state
  • best is selected using the best rolling Avg100, not the best single episode
  • best_score is selected by the highest single-episode score
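Since each checkpoint is an `.npz` weight archive paired with a `.json` metadata file, loading one outside the runner is straightforward. A sketch, assuming the archive keys and metadata fields are whatever the learner saved (the names below are illustrative, not the runner's actual schema):

```python
import json
import numpy as np

def load_checkpoint(prefix):
    """Load `<prefix>.npz` (weights) and `<prefix>.json` (metadata).

    Returns the weights as a dict of name -> array plus the parsed
    metadata dict. Key names depend on what the learner saved.
    """
    weights = dict(np.load(f"{prefix}.npz"))
    with open(f"{prefix}.json") as f:
        meta = json.load(f)
    return weights, meta
```

For example, `load_checkpoint("checkpoints/breakout_cnn_dqn/best")` would return the best-Avg100 weights and their metadata.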

Video Recording

Pass --record during test to capture the best-scoring episode as an MP4:

python -m experiments.flappy.double_dqn --test --best --record
