GitHub Repositories
JAX RL Implementations
Reinforcement Learning in JAX.
CleanRL
High-quality single-file implementations of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
Dopamax
Reinforcement learning in pure JAX.
Earl
Reinforcement learning with Equinox
FISOR
[ICLR 2024] The official implementation of 'Safe Offline Reinforcement Learning with Feasibility-Guided Diffusion Model'
Implicit Q-Learning
This repository contains the official implementation of Offline Reinforcement Learning with Implicit Q-Learning
Jax-Baseline
Jax-Baseline is a Reinforcement Learning implementation using JAX and Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.
JAX-Corl
Clean single-file implementation of offline RL algorithms in JAX
JaxCQL
Conservative Q-Learning in JAX
JaxGCRL
Goal-Conditioned Reinforcement Learning with JAX
JaxIRL
JaxIRL contains JAX implementations of algorithms for inverse reinforcement learning.
JaxMARL
Multi-Agent Reinforcement Learning with JAX
JaxRL
JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.
jym
JAX implementation of RL algorithms and vectorized environments
Mava
A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX.
pax
Scalable Opponent Shaping Experiments in JAX
popjaxrl
Benchmarking RL for POMDPs in Pure JAX [Code for 'Structured State Space Models for In-Context Reinforcement Learning' (NeurIPS 2023)]
purejaxql
Simple single-file baselines for Q-Learning in pure-GPU setting
PureJaxRL
Really Fast End-to-End Jax RL Implementations.
Reinforced-lib
Reinforcement learning library.
Rejax
Hardware-Accelerated Reinforcement Learning Algorithms in pure Jax!
RL Basics
Simple single-file implementations of Deep Reinforcement Learning algorithms.
RL-Flax
Various reinforcement learning algorithms written in Jax + Flax.
RLax
Building blocks for implementing RL agents.
rlbase_stable
This is a codebase that implements simple reinforcement learning algorithms in JAX.
SBX
Stable Baselines Jax (SB3 + Jax) RL algorithms
sebulba
🪐 The Sebulba architecture to scale reinforcement learning on Cloud TPUs in JAX
Skrl
Modular reinforcement learning library (on PyTorch and JAX) with support for NVIDIA Isaac Gym, Omniverse Isaac Gym and Isaac Lab
Stoix
A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX.
UniFloral
Unified Implementations of Offline Reinforcement Learning Algorithms
Wsrl
JAX implementation of WSRL and RL baselines | ICLR 2025
Network Libraries
Neural networks and scientific computing in JAX.
Training
Repositories tailored to optimize and enhance the training process in JAX.
Optimization Libraries
Gradient processing and optimization in JAX.
Optax
Optax is a gradient processing and optimization library for JAX.
RL Environments in JAX
Reinforcement Learning environments written in JAX.
Brax
Massively parallel rigid-body physics simulation on accelerator hardware.
Evojax
EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit.
Evorl
EvoRL is a fully GPU-accelerated framework for Evolutionary Reinforcement Learning, implemented with JAX
Gymnax
RL Environments in JAX.
JaxMARL
Multi-Agent Reinforcement Learning with JAX
Jumanji
A diverse suite of scalable reinforcement learning environments in JAX.
Kinetix
Reinforcement learning on general 2D physics environments in JAX.
Navix
Accelerated minigrid environments with JAX
Pgx
Vectorized RL game environments in JAX
popjym
POPGym Library in JAX
XLand-MiniGrid
JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid
Tools
Additional repositories that may be of interest.
Chex
Chex is a library of utilities for helping to write reliable JAX code.
Distrax
Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP).
Haliax
Named Tensors for Legible Deep Learning in JAX.
JaxTyping
Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.
Orbax
Orbax provides common checkpointing and persistence utilities for JAX users.
Treescope
An interactive HTML pretty-printer for machine learning research in IPython notebooks.