Minimal GRPO (Group Relative Policy Optimization) training loop for LLMs. Built from scratch with vLLM for fast generation and DeepSpeed ZeRO-2 for memory-efficient training.
┌─────────────┐ HTTP ┌─────────────────┐
│ vLLM Server │◄────────────►│ Orchestrator │
│ (generation)│ /generate │ (main loop) │
└─────────────┘ /reload │ │
│ ┌────────────┐ │
│ │ Trainer │ │
│ │ (DeepSpeed)│ │
│ └────────────┘ │
│ ┌────────────┐ │
│ │ Rewarder │ │
│ └────────────┘ │
└─────────────────┘
Per training step:
- Orchestrator sends prompts (repeated `group_size` times each) to vLLM
- vLLM generates completions + old log probs
- Rewarder scores each completion
- Trainer computes group-relative advantages, runs K gradient epochs with PPO-style clipped loss
- Updated weights saved to disk, vLLM hot-reloads them
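The per-step flow above can be sketched as a single function. This is an illustrative skeleton, not the repo's actual API: `generate`, `score`, `compute_advantages`, and `train_epoch` are stand-in callables for the vLLM client, Rewarder, and Trainer.

```python
# Hedged sketch of one GRPO training step; all callables are stand-ins.
def training_step(prompts, group_size, num_epochs,
                  generate, score, compute_advantages, train_epoch):
    # Repeat each prompt group_size times so vLLM samples a full group per prompt
    batch = [p for p in prompts for _ in range(group_size)]
    # vLLM rollout: completions plus the old (behavior-policy) log probs
    completions, old_logps = generate(batch)
    # Rewarder scores every completion
    rewards = [score(c) for c in completions]
    # Group-relative advantage normalization
    advantages = compute_advantages(rewards, group_size)
    # K gradient epochs with the old log probs held frozen
    for _ in range(num_epochs):
        train_epoch(completions, old_logps, advantages)
    return rewards, advantages
```

The key GRPO detail is that `old_logps` is computed once per rollout and reused across all gradient epochs, while the new log probs change every epoch.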
rl-max/
├── src/
│ ├── vllm_client.py # vLLM FastAPI server for rollout generation
│ ├── trainer.py # GRPO trainer with DeepSpeed ZeRO-2
│ ├── orchestrator.py # Main training loop (generate → reward → train → sync)
│ ├── rewards.py # Pluggable reward registry + built-in rewards
│ └── data.py # Dataset loading (WIP)
├── config/
│ ├── ds_config.json # DeepSpeed config (ZeRO-2, CPU optimizer offload)
│ └── config.yaml # Training config (WIP)
├── checkpoints/
│ └── policy/ # HuggingFace model checkpoint (shared by vLLM + trainer)
├── run.py # Quick API test script
└── requirements.txt
# Install dependencies
pip install -r requirements.txt
# Download model (Qwen 2.5 1.5B Instruct)
python3 -c "
from huggingface_hub import snapshot_download
snapshot_download('Qwen/Qwen2.5-1.5B-Instruct', local_dir='checkpoints/policy')
"# Single GPU
python3 -m src.vllm_client --model-path /workspace/rl-max/checkpoints/policy
# With tensor parallelism
python3 -m src.vllm_client --model-path /workspace/rl-max/checkpoints/policy --tp 2

# Use a different GPU than vLLM
CUDA_VISIBLE_DEVICES=1 deepspeed --num_gpus 1 -m src.orchestrator

If running on a single GPU (vLLM + trainer share it):
CUDA_VISIBLE_DEVICES=0 deepspeed --num_gpus 1 -m src.orchestrator

Quick API test:

python3 run.py

FastAPI server wrapping vLLM for batched generation. Endpoints:
- `POST /generate` — batch generate completions with log probs
- `POST /reload_weights` — hot-reload model weights from disk after training
- `GET /health` — health check
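A minimal client sketch for talking to the server. The exact request schema is an assumption: the field names `prompts` and `max_tokens` are illustrative, not confirmed from `vllm_client.py`.

```python
import json
import urllib.request

def build_generate_request(prompts, group_size, max_tokens=256):
    """Repeat each prompt group_size times so the server samples a full group.

    Field names here are hypothetical; check the server's actual schema.
    """
    repeated = [p for p in prompts for _ in range(group_size)]
    return {"prompts": repeated, "max_tokens": max_tokens}

def post_json(url, payload, timeout=60):
    """POST a JSON payload and decode the JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())

# Usage (requires the server to be running):
# out = post_json("http://localhost:8000/generate",
#                 build_generate_request(["What is 2+2?"], group_size=4))
```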
GRPO trainer with DeepSpeed ZeRO-2. Key methods:
- `prepare_inputs()` — tokenize prompt+completion, build attention/completion masks
- `compute_token_log_probs()` — forward pass → shifted log-softmax → per-token log probs
- `sequence_log_probs()` — sum token log probs over completion mask
- `compute_advantages()` — group-relative normalization: `(r - mean) / (std + ε)`
- `grpo_loss()` — PPO-style clipped objective at sequence level
- `save_checkpoint()` — save HF-format weights for vLLM reload
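The group-relative normalization is simple enough to show in full. A pure-Python sketch (the repo's `compute_advantages()` likely operates on tensors, but the math is the same):

```python
import statistics

def compute_advantages(rewards, group_size, eps=1e-4):
    """Group-relative normalization: A = (r - mean(group)) / (std(group) + eps).

    rewards is a flat list where each consecutive run of group_size entries
    belongs to the same prompt's group.
    """
    advantages = []
    for i in range(0, len(rewards), group_size):
        group = rewards[i:i + group_size]
        mean = statistics.fmean(group)
        std = statistics.pstdev(group)  # population std, as in (r - mean) / (std + eps)
        advantages += [(r - mean) / (std + eps) for r in group]
    return advantages
```

The `eps` term keeps the division finite when every completion in a group gets the same reward, in which case all advantages collapse to zero and the group contributes no gradient.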
Pluggable reward system with auto-registration:
- `LengthReward` — rewards longer completions (normalized to [0, 1])
- `ThinkFormatReward` — rewards `<think>...</think>` structure
- Add custom rewards by subclassing `BaseReward` and decorating with `@register_reward`
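A minimal sketch of what such an auto-registering reward might look like. The actual `BaseReward` interface and `@register_reward` signature in `rewards.py` may differ; this only illustrates the registry pattern.

```python
# Hypothetical registry sketch; the repo's real interface may differ.
REWARD_REGISTRY = {}

def register_reward(cls):
    """Class decorator: auto-register a reward under its class name."""
    REWARD_REGISTRY[cls.__name__] = cls
    return cls

class BaseReward:
    def __call__(self, completion: str) -> float:
        raise NotImplementedError

@register_reward
class ThinkFormatReward(BaseReward):
    """1.0 if the completion contains a well-ordered <think>...</think> block."""
    def __call__(self, completion: str) -> float:
        start = completion.find("<think>")
        end = completion.find("</think>")
        return 1.0 if 0 <= start < end else 0.0
```

With this pattern, adding a reward is just defining a decorated subclass; the orchestrator can look up rewards by name from the registry.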
Main training loop:
- Repeat each prompt `group_size` times → send to vLLM
- Score completions with `Rewarder`
- Run `num_epochs` gradient updates per batch (old log probs frozen, new log probs updated each epoch)
- Save checkpoint + reload vLLM weights
For each training step:
1. Sample prompts P₁, P₂, ..., Pₙ
2. For each Pᵢ, generate G completions → group of G samples
3. Score each completion with reward function → rᵢⱼ
4. Compute group-relative advantage: Aᵢⱼ = (rᵢⱼ - mean(rᵢ)) / (std(rᵢ) + ε)
5. For each epoch k = 1..K:
a. Forward pass → new log probs
b. ratio = exp(new_log_prob - old_log_prob)
c. loss = -mean(min(ratio × A, clip(ratio, 1±ε) × A))
d. Backward + optimizer step
6. Save weights → reload vLLM
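Steps 5b–5c are the core of the update. A pure-Python sketch of the sequence-level clipped loss (the trainer presumably uses torch tensors, but the arithmetic is identical):

```python
import math

def grpo_loss(new_logps, old_logps, advantages, clip_eps=0.2):
    """PPO-style clipped objective at sequence level.

    Each entry is one sequence's summed completion log prob (new vs. old)
    paired with its group-relative advantage.
    """
    losses = []
    for new_lp, old_lp, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new_lp - old_lp)                       # importance ratio
        clipped = max(min(ratio, 1 + clip_eps), 1 - clip_eps)   # clip to [1-ε, 1+ε]
        losses.append(-min(ratio * adv, clipped * adv))         # pessimistic objective
    return sum(losses) / len(losses)
```

Taking the `min` of the unclipped and clipped terms means the update never benefits from pushing the ratio outside the clip range, which bounds how far one batch can move the policy.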
- Sequence-level loss: ratios are computed at the sequence level (sum of token log probs), not per-token. Can be unstable for long completions.
- No KL penalty: removed, since RLVR-style setups often don't need it; it may still be necessary to prevent reward hacking.
- No micro-batching: entire batch goes through one forward pass. Large batches or long sequences may OOM.
- Disk-based weight sync: checkpoint save + vLLM reload is slow. Production systems use shared memory.
- Hardcoded prompts: uses 4 fixed prompts. Needs proper dataset integration.
- Single-node only: designed for single-node multi-GPU. No multi-node support.
- KL divergence penalty against reference policy (removed as not currently used)
- Token-level clipped loss (more stable)
- Dataset loader in `data.py`
- ZeRO-3 checkpoint saving