Rejection Sampling as RL

Rejection sampling fine-tuning filters a model's own outputs by correctness and trains on the survivors, achieving a policy-improvement step that corresponds to one round of policy-gradient RL with a binary reward, but without an explicit RL optimiser loop.

intermediate · 8 min read

When Yuan et al. published their rejection sampling fine-tuning (RFT) results in 2023, they reported LLaMA-7B improving from 35.9% with plain supervised fine-tuning to 49.3% on GSM8K, using no reward model and no RL library. The only supervision beyond the original training set was the dataset's reference final answers, which were used to check correctness. The mechanism was almost embarrassingly simple: sample multiple solutions, keep the ones with correct answers, and train on those. (The 49.3% figure pools accepted samples from several models rather than from the 7B model alone.) No PPO, no KL coefficient tuning, no critic network. Yet the downstream improvement was real and measurable.

This prompts the question: is RFT actually RL, or just clever data augmentation? The answer is that it is both, and understanding exactly why resolves a lot of confusion about what RL for language models is fundamentally doing.

The Core Equivalence

Standard RL for language models maximises an expected reward under the current policy. Write the policy as pi_theta, the prompt distribution as p(x), and the reward function as r(y | x). The RL objective is:

J(theta) = E_{x ~ p(x), y ~ pi_theta(. | x)} [ r(y | x) ]

A policy gradient step nudges theta to increase the probability of high-reward completions and decrease the probability of low-reward ones. The gradient estimator weights each sampled completion by its reward (or by its advantage relative to a baseline).

Rejection sampling fine-tuning does something structurally identical. For each prompt x, sample a group of G completions. Discard those where r = 0 (wrong answer). Train on the survivors via supervised fine-tuning. The first gradient step of that fine-tuning is a policy gradient step in which the reward weight is binary: +1 for correct completions, 0 for incorrect ones, with no gradient flowing from the discarded outputs at all.

Formally, the SFT cross-entropy loss on filtered samples is:

L_RFT = - E_{x} [ (1/|Y+|) * sum_{y in Y+} log pi_theta(y | x) ]

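where Y+ is the set of accepted (correct) completions for x. With a binary reward, averaging log pi_theta(y | x) over Y+ is the same as summing r(y|x) * log pi_theta(y | x) over all G samples and dividing by |Y+|: the zero-reward samples silently drop out. Evaluated at the parameters that generated the samples, and up to that per-prompt normalisation (|Y+| rather than G), the gradient is: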

grad L_RFT ≈ - E[ r(y|x) * grad log pi_theta(y | x) ]

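which is the negative of the REINFORCE estimate of grad J, so descending L_RFT ascends J. RFT is REINFORCE with binary rewards and no baseline subtraction, run for one round of sampling and expressed as a supervised objective. The correspondence is exact for the first gradient step; later steps and epochs reuse the same samples, so they are increasingly off-policy. The RL framing is not a post-hoc rationalisation; it is the same computation wearing different clothes.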
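To make the correspondence concrete, here is a minimal sketch on a toy softmax policy over four discrete "completions". The logits, the set of accepted completions, and all function names are illustrative assumptions, not anything from the RFT paper. It checks numerically that the gradient of the filtered SFT loss equals minus the REINFORCE estimate, rescaled by G / |Y+|.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_log_prob(probs, y):
    # For a softmax policy, d log pi(y) / d logits = onehot(y) - probs.
    return [(1.0 if k == y else 0.0) - p for k, p in enumerate(probs)]

def reinforce_grad(probs, samples, rewards):
    # REINFORCE estimate of grad J: average of r * grad log pi over all G samples.
    grad = [0.0] * len(probs)
    for y, r in zip(samples, rewards):
        for k, d in enumerate(grad_log_prob(probs, y)):
            grad[k] += r * d / len(samples)
    return grad

def rft_loss_grad(probs, samples, rewards):
    # Gradient of the SFT loss: minus the average of grad log pi over accepted samples.
    accepted = [y for y, r in zip(samples, rewards) if r == 1]
    grad = [0.0] * len(probs)
    for y in accepted:
        for k, d in enumerate(grad_log_prob(probs, y)):
            grad[k] -= d / len(accepted)
    return grad, len(accepted)

random.seed(0)
logits = [0.2, -0.5, 1.0, 0.0]  # hypothetical policy parameters
correct = {2, 3}                # hypothetical completions a verifier accepts
probs = softmax(logits)
samples = random.choices(range(len(probs)), weights=probs, k=64)
rewards = [1 if y in correct else 0 for y in samples]

pg = reinforce_grad(probs, samples, rewards)
sft, n_accepted = rft_loss_grad(probs, samples, rewards)
scale = len(samples) / n_accepted  # G / |Y+|
max_gap = max(abs(s + scale * g) for s, g in zip(sft, pg))
print(max_gap)  # zero up to floating-point error
```

The only difference between the two quantities is the sign (a loss is minimised, a return is maximised) and the per-prompt normalisation.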

What Changes With Iteration

One round of RFT on a fixed sample set is a single policy improvement step. The improvement saturates quickly because the training set was generated from the old policy; the model will soon fit those examples and the marginal gain from further epochs diminishes.

The natural fix is to iterate: after training, regenerate the sample set from the updated model and filter again. This is called iterative or online RFT, and it converts a one-shot supervised procedure into a genuine on-policy RL loop:

repeat:
  1. Sample G completions per prompt from pi_current
  2. Filter to Y+ using the reward function
  3. Fine-tune pi_current -> pi_next on Y+
  4. Set pi_current = pi_next

Each iteration is a policy improvement step under the REINFORCE gradient. The loop resembles a stripped-down PPO: no clipped surrogate objective and no critic, but the same on-policy data-collection requirement. The Llama 2 paper (Touvron et al., 2023) reports a closely related structure in its RLHF post-training: successive RLHF versions were trained with rejection sampling fine-tuning, keeping the sample that a reward model scored highest for each prompt, and in the later versions PPO was applied on top of the resulting rejection-sampling checkpoint.
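The four-step loop above can be sketched on a toy problem. This is an illustration under strong assumptions: each "prompt" has five candidate answers, the policy is a table of logits per prompt, the filter is an exact-match check, and "fine-tuning" is a single gradient step on the SFT loss. The names `rft_round` and `mean_accuracy` are hypothetical.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rft_round(logits, correct, group_size, lr, rng):
    """One RFT round for one prompt: sample, filter, one SFT gradient step."""
    probs = softmax(logits)
    # Step 1: sample G completions from the current policy.
    samples = rng.choices(range(len(logits)), weights=probs, k=group_size)
    # Step 2: keep only the completions the checker accepts.
    accepted = [y for y in samples if y == correct]
    if not accepted:
        return list(logits)  # no survivors, so no training data for this prompt
    # Step 3: one gradient step on the SFT loss over the accepted set.
    updated = list(logits)
    for y in accepted:
        for k, p in enumerate(probs):
            updated[k] += lr * ((1.0 if k == y else 0.0) - p) / len(accepted)
    return updated

def mean_accuracy(policy, answers):
    # Average probability the policy assigns to the correct answer.
    return sum(softmax(policy[p])[answers[p]] for p in policy) / len(policy)

rng = random.Random(0)
answers = {p: p % 5 for p in range(20)}     # toy ground truth
policy = {p: [0.0] * 5 for p in range(20)}  # uniform initial policy

before = mean_accuracy(policy, answers)
for _ in range(10):
    # Step 4: the updated policy becomes the sampling policy for the next round.
    policy = {p: rft_round(policy[p], answers[p], 8, 1.0, rng) for p in policy}
after = mean_accuracy(policy, answers)
print(before, after)
```

Because every round resamples from the updated table, the data stays on-policy; freezing the samples after the first round would reduce this to single-pass RFT.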

Property                | Single-pass RFT     | Iterative RFT              | PPO
------------------------|---------------------|----------------------------|---------------------------
On-policy data          | No (fixed sample)   | Yes (refreshed each round) | Yes
Explicit critic         | No                  | No                         | Yes
KL regularisation       | Implicit (SFT loss) | Implicit                   | Explicit coefficient
Memory cost             | Low                 | Low                        | High (4 model copies)
Theoretical convergence | One step            | Iterates to local optimum  | Iterates with trust region

The implicit KL regularisation deserves a note. Because RFT fine-tunes with standard cross-entropy, the model is being pulled towards the filtered distribution. It is not explicitly penalised for drifting from the reference policy. In practice this means RFT can drift further per iteration than PPO would allow, which can be either a feature (faster change) or a bug (instability and distribution collapse) depending on how aggressively you iterate.

Best-of-N as the Inference-Time Counterpart

Rejection sampling at inference time is called best-of-N (or generate-then-rank). Sample N completions, score each with a reward model or verifier, return the highest-scoring one. No gradient, no parameter update. This is RFT's cousin: the same filter-on-reward idea applied to deployment rather than training.
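A minimal sketch of best-of-N, assuming a sampler and a scorer are supplied as plain callables. The function `best_of_n` and the toy sampler and scorer are hypothetical stand-ins for a real model and a real verifier or reward model.

```python
import itertools
from typing import Callable

def best_of_n(prompt: str,
              sample_fn: Callable[[str], str],
              score_fn: Callable[[str, str], float],
              n: int) -> str:
    """Draw n completions and return the one with the highest score."""
    candidates = [sample_fn(prompt) for _ in range(n)]
    return max(candidates, key=lambda completion: score_fn(prompt, completion))

# Toy stand-ins: a canned "model" and an exact-match "verifier".
_canned = itertools.cycle(["5", "22", "4", "3"])

def toy_sample(prompt: str) -> str:
    return next(_canned)

def toy_score(prompt: str, completion: str) -> float:
    return 1.0 if completion == "4" else 0.0

answer = best_of_n("2 + 2 = ?", toy_sample, toy_score, n=4)
print(answer)
```

No parameters change here; the same filter that builds Y+ during RFT training is simply applied to the outputs returned to the user.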

Gao et al. (2022) studied how proxy reward and gold reward diverge as N scales. For best-of-N, the KL divergence from the base policy grows only logarithmically in N (the paper uses log N - (N-1)/N), and as N increases the proxy reward keeps rising while the gold reward eventually peaks and falls - the same reward over-optimisation shape seen in PPO, just with the optimiser replaced by sampling volume. This tells you something important: the mechanism that causes reward over-optimisation is not specific to gradient-based RL. It is a property of any procedure that searches for high-reward outputs under an imperfect scorer. RFT training is subject to the same pressure; iterating many rounds against a learned reward model will eventually produce the same pathologies as over-optimised PPO.
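Gao et al. measure optimisation pressure as KL divergence from the initial policy, and for best-of-N they use the closed form log N - (N-1)/N. The sketch below only evaluates that expression; the function name `best_of_n_kl` is an assumption made for illustration.

```python
import math

def best_of_n_kl(n: int) -> float:
    """KL (in nats) between the best-of-n policy and the base policy,
    using the closed form log n - (n - 1) / n."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return math.log(n) - (n - 1) / n

for n in (1, 4, 16, 64, 256):
    print(n, round(best_of_n_kl(n), 3))
```

Multiplying N by four adds only about log 4 nats of KL, which is why best-of-N needs very large sample counts to apply the optimisation pressure a few PPO updates can reach.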

Where the KL Term Went

Standard KL-regularised RL adds an explicit penalty:

J_KL(theta) = E[ r(y|x) ] - beta * KL[ pi_theta || pi_ref ]

RFT drops this term entirely. In practice, three things substitute for it:

  1. Short training duration. Fine-tuning for a handful of epochs on the filtered set does not move the policy far from initialisation.
  2. Early stopping. Practitioners typically stop when validation reward saturates, before the model drifts into degenerate solutions.
  3. Merging rounds. Some pipelines mix filtered data with the original SFT data in each round, acting as a soft anchor to the reference distribution.

None of these is as principled as an explicit KL penalty. When RFT is iterated aggressively without these safeguards, the model can drift into overfit solutions: it learns to produce outputs that match the training-time filter but generalises poorly, exactly analogous to PPO reward hacking.
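For comparison, the explicit penalty that RFT omits is easy to evaluate on a toy categorical policy. In this sketch the reference distribution, the rewards, the two candidate policies, and beta = 0.5 are all made-up values chosen only to show the trade-off.

```python
import math

def kl_divergence(p, q):
    # KL[p || q] for discrete distributions; terms with p_i = 0 contribute 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def kl_regularised_objective(policy, reference, rewards, beta):
    # J_KL = E[r] - beta * KL[policy || reference]
    expected_reward = sum(p * r for p, r in zip(policy, rewards))
    return expected_reward - beta * kl_divergence(policy, reference)

reference = [0.25, 0.25, 0.25, 0.25]  # hypothetical reference policy
rewards = [0.0, 0.0, 1.0, 1.0]        # two of four completions are correct
mild = [0.1, 0.1, 0.4, 0.4]           # shifts mass towards both correct answers
collapsed = [0.0, 0.0, 1.0, 0.0]      # all mass on a single correct answer

raw_mild = kl_regularised_objective(mild, reference, rewards, beta=0.0)
raw_collapsed = kl_regularised_objective(collapsed, reference, rewards, beta=0.0)
reg_mild = kl_regularised_objective(mild, reference, rewards, beta=0.5)
reg_collapsed = kl_regularised_objective(collapsed, reference, rewards, beta=0.5)
print(raw_mild, raw_collapsed, reg_mild, reg_collapsed)
```

With beta = 0 the collapsed policy scores higher; with the penalty switched on the ranking flips. Plain RFT has nothing that plays this role unless one of the three substitutes above is in place.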

When It Falls Down

No gradient signal for hard problems. If the model's pass rate on a problem is near zero, it produces no correct samples, so no gradient flows. The training set skews towards problems the model already partially solves, and hard problems are ignored. PPO with a dense reward can still learn from partially successful attempts. GRPO with within-group normalisation does not escape the all-wrong case either (every sample gets zero advantage, which is a stable no-update), but as soon as a group is mixed it also learns from the incorrect samples by pushing their probability down. RFT learns only from the correct ones, and when there are none it simply produces no data.
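A quick worked calculation shows how fast the data disappears. Assuming the G samples for a prompt are independent and each is correct with probability p (the pass rate), the chance that the prompt contributes nothing to the training set is (1 - p)^G. The function name is an illustrative assumption.

```python
def prob_no_correct(pass_rate: float, group_size: int) -> float:
    """Probability that none of group_size independent samples is correct."""
    return (1.0 - pass_rate) ** group_size

for pass_rate in (0.001, 0.01, 0.05, 0.2):
    print(pass_rate, round(prob_no_correct(pass_rate, 64), 3))
```

Under this assumption, a prompt with a 1% pass rate is still dropped from roughly half of all rounds at G = 64.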

Reward model as filter introduces all the standard biases. When the correctness criterion is a learned reward model rather than a verifier, the filter inherits the reward model's blind spots. The training set is shaped by the reward model's preferences, so iterative RFT against a biased reward model is iterative reward hacking. The advantage of using a verifier (a deterministic checker) is that this problem disappears; the disadvantage is that verifiers only exist for narrow task types.

Sample efficiency degrades at scale. To get a useful training signal on a hard problem, you may need to sample dozens of completions before one is correct. At 70B parameters, generating 64 rollouts per prompt per iteration is expensive. PPO amortises this by extracting a gradient from every rollout, not just the correct ones. RFT discards most of its compute.
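The cost can be put in rough numbers. This sketch assumes independent samples with a fixed pass rate, so the number of draws until the first correct sample is geometric with mean 1/p; the function and its dictionary keys are hypothetical.

```python
def rollout_budget(pass_rate: float, group_size: int) -> dict:
    """Expected kept and discarded rollouts per prompt for a given pass rate."""
    expected_kept = pass_rate * group_size
    return {
        "expected_kept": expected_kept,
        "expected_discarded": group_size - expected_kept,
        # Mean of a geometric distribution; infinite if the model never succeeds.
        "expected_draws_to_first_correct":
            1.0 / pass_rate if pass_rate > 0 else float("inf"),
    }

budget = rollout_budget(pass_rate=0.25, group_size=64)
print(budget)
```

Every discarded rollout still costs a full generation pass, so the fraction 1 - p is a reasonable first estimate of the sampling compute that never reaches the training set.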

Implicit KL drift. Without an explicit regularisation term, iterative RFT can shift the output distribution far from the reference model across many rounds. This is hard to detect from training metrics alone and tends to manifest as reduced diversity in later iterations, eventually collapsing to a handful of solution templates.
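Because this kind of collapse does not show up in the loss, one option is to track a diversity statistic over each round's samples. The two metrics below are one possible choice rather than a standard recipe, and they treat completions as opaque strings; in practice you might normalise or cluster solutions first.

```python
import math
from collections import Counter

def distinct_fraction(samples):
    """Share of samples that are unique; 1.0 means no repeats."""
    return len(set(samples)) / len(samples)

def empirical_entropy(samples):
    """Shannon entropy (nats) of the empirical distribution over samples."""
    counts = Counter(samples)
    total = len(samples)
    return -sum((c / total) * math.log(c / total) for c in counts.values())

early_round = ["sol_a", "sol_b", "sol_c", "sol_d"]  # hypothetical round-1 samples
late_round = ["sol_a", "sol_a", "sol_a", "sol_b"]   # hypothetical round-10 samples
print(distinct_fraction(early_round), distinct_fraction(late_round))
print(empirical_entropy(early_round), empirical_entropy(late_round))
```

A steady fall in either number across rounds, while training reward keeps rising, is the pattern described above.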

One-step instability. A single RFT round with a very small correct-sample rate (say, 1 in 64 completions) places enormous weight on a tiny set of survivors. Those survivors may be atypical solutions that happen to match the filter but are not representative of good reasoning. The model can overfit to idiosyncratic correct solutions rather than learning general principles.

Further Reading

  • Yuan et al. (2023). Scaling Relationship on Learning Mathematical Reasoning with Large Language Models. arxiv.org/abs/2308.01825 - Introduces rejection sampling fine-tuning for maths and provides the scaling analysis of RFT as a training method.
  • Touvron et al. (2023). Llama 2: Open Foundation and Fine-Tuned Chat Models. arxiv.org/abs/2307.09288 - Details the iterative rejection sampling fine-tuning used in Llama 2-Chat post-training and how PPO was later applied on top of it.
  • Gao, Schulman & Hilton (2022). Scaling Laws for Reward Model Overoptimization. arxiv.org/abs/2210.10760 - Empirically characterises over-optimisation under both RL and best-of-N selection, establishing the unified picture of reward gaming.
  • Shao et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. arxiv.org/abs/2402.03300 - Introduces GRPO, which adds within-group advantage normalisation so that incorrect samples contribute a negative learning signal instead of being discarded as in RFT.