TRL documentation

A2PO


TRL supports A*-PO (Optimal Advantage Regression) as described in the paper Accelerating RL for LLM Reasoning with Optimal Advantage Regression by Kianté Brantley, Mingyu Chen, Zhaolin Gao, Jason D. Lee, Wen Sun, Wenhao Zhan, and Xuezhou Zhang.

The abstract from the paper is the following:

Reinforcement learning (RL) has emerged as a powerful tool for fine-tuning large language models (LLMs) to improve complex reasoning abilities. However, state-of-the-art policy optimization methods often suffer from high computational overhead and memory consumption, primarily due to the need for multiple generations per prompt and the reliance on critic networks or advantage estimates of the current policy. In this paper, we propose A*-PO, a novel two-stage policy optimization framework that directly approximates the optimal advantage function and enables efficient training of LLMs for reasoning tasks. In the first stage, we leverage offline sampling from a reference policy to estimate the optimal value function V*, eliminating the need for costly online value estimation. In the second stage, we perform on-policy updates using a simple least-squares regression loss with only a single generation per prompt. Theoretically, we establish performance guarantees and prove that the KL-regularized RL objective can be optimized without requiring complex exploration strategies. Empirically, A*-PO achieves competitive performance across a wide range of mathematical reasoning benchmarks, while reducing training time by up to 2× and peak memory usage by over 30% compared to PPO, GRPO, and REBEL.

Usage

A*-PO assumes a binary, verifiable reward (r ∈ {0, 1}) and runs in two stages:

  1. Offline value estimation. Before training, num_value_samples completions are sampled from the reference policy for every prompt and scored with reward_funcs. The optimal value V*(x) = β₁·log(mean_i exp(r(x, yᵢ)/β₁)) is estimated and cached per prompt.
  2. On-policy regression. During training, a single completion is generated per prompt from the current policy. The loss is the squared error between the implicit reward β₂·log(π(y|x)/π_ref(y|x)) and the optimal advantage r(x, y) − V*(x).
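
For a binary reward, the Stage 1 estimate has a closed form. Writing n for num_value_samples and c for the number of correct reference completions for prompt x (c is notation introduced here for illustration), the log-mean-exp collapses to a function of c alone:

```latex
\begin{aligned}
V^*(x) &= \beta_1 \log\Big(\frac{1}{n}\sum_{i=1}^{n} e^{r(x, y_i)/\beta_1}\Big)
        = \beta_1 \log\Big(\frac{c\, e^{1/\beta_1} + (n - c)}{n}\Big) \\
c = 0 &\Rightarrow V^*(x) = 0, \qquad c = n \Rightarrow V^*(x) = 1 \\
\lim_{\beta_1 \to 0^+} V^*(x) &= \mathbb{1}[c \ge 1], \qquad \lim_{\beta_1 \to \infty} V^*(x) = c/n \\
n = 8,\ c = 2,\ \beta_1 = 0.5 &\Rightarrow V^*(x) = 0.5 \log\big((2e^{2} + 6)/8\big) \approx 0.477
\end{aligned}
```

So V*(x) interpolates between the empirical mean reward (large β₁) and whether any reference sample was correct (small β₁). With the numbers in the last line, the Stage 2 target r(x, y) − V*(x) is about 0.523 for a correct completion and −0.477 for an incorrect one.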

```python
from trl.experimental.a2po import A2POConfig, A2POTrainer

# A*-PO assumes a binary, verifiable reward in {0, 1}.
def reward_correct(completions, ground_truth, **kwargs):
    return [float(completion.strip() == truth) for completion, truth in zip(completions, ground_truth)]

training_args = A2POConfig(
    output_dir="Qwen2.5-0.5B-A2PO",
    num_value_samples=8,  # Stage 1: samples per prompt from the reference policy to estimate V*
    beta1=0.5,  # Stage 1: KL temperature for the V* estimate
    beta2=1e-3,  # Stage 2: KL temperature for the regression target
)
trainer = A2POTrainer(
    model="Qwen/Qwen2.5-0.5B",
    reward_funcs=reward_correct,
    args=training_args,
    train_dataset=...,
)
trainer.train()
```

Because V* is estimated entirely from reference-policy samples, A*-PO cannot exceed the reference policy’s Pass@K. The official implementation can be found at ZhaolinGao/A-PO.

A2POTrainer

class trl.experimental.a2po.A2POTrainer


( model: transformers.modeling_utils.PreTrainedModel | str,
  reward_funcs: collections.abc.Callable[..., list[float]] | list[collections.abc.Callable[..., list[float]]],
  args: trl.experimental.a2po.a2po_config.A2POConfig | None = None,
  train_dataset = None,
  eval_dataset = None,
  processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | None = None,
  callbacks = None,
  optimizers = (None, None) )

Parameters

  • model (PreTrainedModel or str) — Model to be trained, or a model identifier (string) passed to from_pretrained.
  • reward_funcs (Callable or list[Callable]) — Reward function(s). Each takes prompts and completions (plus dataset columns as keyword arguments) and returns a list of float rewards. When multiple are provided, their weighted sum (see A2POConfig.reward_weights) is the scalar reward r, which A*-PO assumes to be binary (in {0, 1}).
  • args (A2POConfig, optional) — Configuration for this trainer. If None, a default configuration is used.
  • train_dataset (Dataset, optional) — Training dataset. Must contain a "prompt" column.
  • eval_dataset (Dataset, optional) — Evaluation dataset.
  • processing_class (PreTrainedTokenizerBase, optional) — Processing class used to process the data. If None, it is loaded from the model’s name with from_pretrained.
  • callbacks (list[~transformers.TrainerCallback], optional) — List of callbacks to customize the training loop.
  • optimizers (tuple[~torch.optim.Optimizer, ~torch.optim.lr_scheduler.LambdaLR], optional, defaults to (None, None)) — Tuple containing the optimizer and the learning rate scheduler.

Trainer for the A*-PO (Optimal Advantage Regression) method, introduced in Accelerating RL for LLM Reasoning with Optimal Advantage Regression.

A*-PO runs in two stages:

  1. Offline value estimation. Before training, num_value_samples completions are sampled from the reference policy for every training prompt and scored with reward_funcs. The optimal value is estimated as V*(x) = beta1 * log(mean_i exp(r(x, y_i) / beta1)) and cached per prompt.
  2. On-policy regression. During training, a single completion is generated per prompt from the current policy. The loss is the squared error between the implicit reward beta2 * log(pi(y|x) / pi_ref(y|x)) and the optimal advantage estimate r(x, y) - V*(x).
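
A minimal sketch of the Stage 2 objective in plain Python, assuming the per-sequence log-probabilities log π(y|x) and log π_ref(y|x) and the cached V*(x) are already available. The helper names sequence_logp and a2po_loss, and averaging the squared error over the batch, are illustrative assumptions rather than the trainer's internal API; a real implementation would operate on autograd tensors so the loss can be backpropagated into π.

```python
def sequence_logp(token_logps):
    """log pi(y|x) is the sum of the per-token log-probabilities of the completion."""
    return sum(token_logps)


def a2po_loss(policy_logps, ref_logps, rewards, values, beta2=1e-3):
    """Hypothetical helper: batch mean of (beta2 * log(pi/pi_ref) - (r - V*))**2.

    Each list holds one entry per prompt (a single completion per prompt):
    policy_logps / ref_logps: log pi(y|x) and log pi_ref(y|x),
    rewards: r(x, y) in {0, 1}, values: cached V*(x) from Stage 1.
    """
    losses = []
    for logp, ref_logp, r, v in zip(policy_logps, ref_logps, rewards, values):
        implicit_reward = beta2 * (logp - ref_logp)  # beta2 * log(pi(y|x) / pi_ref(y|x))
        advantage = r - v  # optimal advantage estimate r(x, y) - V*(x)
        losses.append((implicit_reward - advantage) ** 2)
    return sum(losses) / len(losses)


# At initialization pi == pi_ref, so the implicit reward is 0 and the loss is mean((r - V*)^2).
logp = sequence_logp([-0.5, -1.25, -0.25])  # -2.0
print(a2po_loss([logp, logp], [logp, logp], rewards=[1.0, 0.0], values=[0.5, 0.5]))  # 0.25
```

Minimizing this loss pushes log π(y|x) above log π_ref(y|x) for a completion whose reward exceeds V*(x) and below it otherwise; β₂ sets how far the policy must move away from the reference to match a given advantage.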

train


( *args, **kwargs )

save_model


( output_dir: str | None = None, _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

push_to_hub


( commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.

A2POConfig

class trl.experimental.a2po.A2POConfig


( output_dir: str | None = None,
  per_device_train_batch_size: int = 8,
  num_train_epochs: float = 3.0,
  max_steps: int = -1,
  learning_rate: float = 5e-05,
  lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear',
  lr_scheduler_kwargs: dict | str | None = None,
  warmup_steps: float = 0,
  optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused',
  optim_args: str | None = None,
  weight_decay: float = 0.0,
  adam_beta1: float = 0.9,
  adam_beta2: float = 0.999,
  adam_epsilon: float = 1e-08,
  optim_target_modules: None | str | list[str] = None,
  gradient_accumulation_steps: int = 1,
  average_tokens_across_devices: bool = True,
  max_grad_norm: float = 1.0,
  label_smoothing_factor: float = 0.0,
  bf16: bool | None = None,
  fp16: bool = False,
  bf16_full_eval: bool = False,
  fp16_full_eval: bool = False,
  tf32: bool | None = None,
  gradient_checkpointing: bool = True,
  gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None,
  torch_compile: bool = False,
  torch_compile_backend: str | None = None,
  torch_compile_mode: str | None = None,
  use_liger_kernel: bool = False,
  liger_kernel_config: dict[str, bool] | None = None,
  use_cache: bool = False,
  neftune_noise_alpha: float | None = None,
  torch_empty_cache_steps: int | None = None,
  auto_find_batch_size: bool = False,
  logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps',
  logging_steps: float = 10,
  logging_first_step: bool = False,
  log_on_each_node: bool = True,
  logging_nan_inf_filter: bool = True,
  include_num_input_tokens_seen: str | bool = 'no',
  log_level: str = 'passive',
  log_level_replica: str = 'warning',
  disable_tqdm: bool | None = None,
  report_to: None | str | list[str] = 'none',
  run_name: str | None = None,
  project: str = 'huggingface',
  trackio_space_id: str | None = None,
  trackio_bucket_id: str | None = None,
  trackio_static_space_id: typing.Union[str, NoneType, typing.Literal[False]] = None,
  eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no',
  eval_steps: float | None = None,
  eval_delay: float = 0,
  per_device_eval_batch_size: int = 8,
  prediction_loss_only: bool = False,
  eval_on_start: bool = False,
  eval_do_concat_batches: bool = True,
  eval_use_gather_object: bool = False,
  eval_accumulation_steps: int | None = None,
  include_for_metrics: list = <factory>,
  batch_eval_metrics: bool = False,
  save_only_model: bool = False,
  save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps',
  save_steps: float = 500,
  save_on_each_node: bool = False,
  save_total_limit: int | None = None,
  enable_jit_checkpoint: bool = False,
  push_to_hub: bool = False,
  hub_token: str | None = None,
  hub_private_repo: bool | None = None,
  hub_model_id: str | None = None,
  hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save',
  hub_always_push: bool = False,
  hub_revision: str | None = None,
  load_best_model_at_end: bool = False,
  metric_for_best_model: str | None = None,
  greater_is_better: bool | None = None,
  ignore_data_skip: bool = False,
  restore_callback_states_from_checkpoint: bool = False,
  full_determinism: bool = False,
  seed: int = 42,
  data_seed: int | None = None,
  use_cpu: bool = False,
  accelerator_config: dict | str | None = None,
  parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None,
  dataloader_drop_last: bool = False,
  dataloader_num_workers: int = 0,
  dataloader_pin_memory: bool = True,
  dataloader_persistent_workers: bool = False,
  dataloader_prefetch_factor: int | None = None,
  remove_unused_columns: bool = False,
  label_names: list[str] | None = None,
  train_sampling_strategy: str = 'random',
  length_column_name: str = 'length',
  ddp_find_unused_parameters: bool | None = None,
  ddp_bucket_cap_mb: int | None = None,
  ddp_broadcast_buffers: bool | None = None,
  ddp_static_graph: bool | None = None,
  ddp_backend: str | None = None,
  ddp_timeout: int = 1800,
  fsdp: str | None = None,
  fsdp_config: dict[str, typing.Any] | str | None = None,
  deepspeed: dict | str | None = None,
  debug: str | list[transformers.debug_utils.DebugOption] = '',
  skip_memory_metrics: bool = True,
  do_train: bool = False,
  do_eval: bool = False,
  do_predict: bool = False,
  resume_from_checkpoint: str | None = None,
  warmup_ratio: float | None = None,
  logging_dir: str | None = None,
  local_rank: int = -1,
  model_init_kwargs: dict | None = None,
  trust_remote_code: bool = False,
  max_prompt_length: int | None = 512,
  max_completion_length: int | None = 256,
  temperature: float = 1.0,
  top_p: float = 1.0,
  top_k: int | None = None,
  num_value_samples: int = 8,
  beta1: float = 0.5,
  filter_all_incorrect: bool = True,
  beta2: float = 0.001,
  reward_weights: list[float] | None = None )

Parameters that control the model and reference model

  • model_init_kwargs (dict[str, Any], optional) — Keyword arguments for from_pretrained, used when the model argument of the A2POTrainer is provided as a string.
  • trust_remote_code (bool, optional, defaults to False) — Whether to allow loading models and tokenizers that ship custom Python code from the Hub. Forwarded to from_pretrained for both the model and the tokenizer.

Parameters that control the data preprocessing

  • remove_unused_columns (bool, optional, defaults to False) — Whether to only keep the column "prompt" in the dataset. If you use a custom reward function that requires any column other than "prompts" and "completions", keep this set to False.

Parameters that control generation

  • max_prompt_length (int or None, optional, defaults to 512) — Maximum length of the prompt. If the prompt is longer than this, it is left-truncated.
  • max_completion_length (int or None, optional, defaults to 256) — Maximum length of the generated completion.
  • temperature (float, optional, defaults to 1.0) — Sampling temperature, used in both Stage 1 and Stage 2 generation.
  • top_p (float, optional, defaults to 1.0) — Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1.0 to consider all tokens.
  • top_k (int or None, optional) — Number of highest-probability vocabulary tokens to keep. If None, top-k filtering is disabled.
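
These generation parameters follow the usual conventions. As a generic sketch (not the trainer's actual generation code, which this page does not show), left truncation and the temperature, top-k and top-p filters can be written as below; the function names are hypothetical, and top-p is applied after top-k on the renormalized probabilities.

```python
import math


def left_truncate(prompt_ids, max_prompt_length):
    """Drop tokens from the start so at most max_prompt_length remain (None disables it)."""
    if max_prompt_length is None:
        return list(prompt_ids)
    return list(prompt_ids[max(0, len(prompt_ids) - max_prompt_length):])


def sampling_distribution(logits, temperature=1.0, top_k=None, top_p=1.0):
    """Return {token_id: probability} after temperature, top-k and top-p filtering."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # unnormalized softmax, shifted for stability
    # Token ids sorted from most to least likely.
    order = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]  # keep only the top_k most likely tokens
    total = sum(weights[i] for i in order)
    probs = {i: weights[i] / total for i in order}  # renormalize over the survivors
    if top_p < 1.0:
        kept, cumulative = {}, 0.0
        for i in order:  # smallest prefix whose probability mass reaches top_p
            kept[i] = probs[i]
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        total = sum(kept.values())
        probs = {i: p / total for i, p in kept.items()}
    return probs


print(left_truncate([1, 2, 3, 4, 5], 3))  # [3, 4, 5]
print(sorted(sampling_distribution([2.0, 1.0, 0.0, -1.0], top_p=0.9)))  # [0, 1, 2]
```

A token would then be drawn from the filtered distribution, for example with random.choices(list(dist), weights=list(dist.values())).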

Parameters that control Stage 1 (offline optimal value estimation)

  • num_value_samples (int, optional, defaults to 8) — Number of samples drawn from the reference policy per prompt to estimate V*.
  • beta1 (float, optional, defaults to 0.5) — KL temperature used to estimate V* in Stage 1.
  • filter_all_incorrect (bool, optional, defaults to True) — Whether to drop prompts for which all reference samples are incorrect.
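
The Stage 1 preprocessing controlled by these parameters can be sketched as follows. This assumes the num_value_samples rewards per prompt have already been computed; the dictionary input format and the function names are hypothetical, and the log-sum-exp shift is a standard numerical-stability choice rather than something this page documents.

```python
import math


def estimate_v_star(rewards, beta1=0.5):
    """V*(x) = beta1 * log(mean_i exp(r_i / beta1)), computed stably via log-sum-exp."""
    scaled = [r / beta1 for r in rewards]
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scaled))
    return beta1 * (log_sum_exp - math.log(len(rewards)))


def build_value_cache(reference_rewards, beta1=0.5, filter_all_incorrect=True):
    """Map each prompt to its cached V*, optionally dropping prompts with no correct sample.

    reference_rewards: dict mapping a prompt to its list of num_value_samples
    binary rewards from the reference policy (hypothetical input format).
    """
    cache = {}
    for prompt, rewards in reference_rewards.items():
        if filter_all_incorrect and max(rewards) == 0:
            continue  # every reference sample was incorrect: drop the prompt
        cache[prompt] = estimate_v_star(rewards, beta1)
    return cache


cache = build_value_cache({"p1": [1, 1, 0, 0, 0, 0, 0, 0], "p2": [0] * 8})
print(sorted(cache))  # ['p1'] (p2 is filtered out)
```

With filter_all_incorrect=False, an all-incorrect prompt is kept with V*(x) = 0, so every completion for it gets a zero target advantage unless it is correct.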

Parameters that control Stage 2 (on-policy regression)

  • beta2 (float, optional, defaults to 1e-3) — KL temperature used in the Stage 2 regression target.
  • reward_weights (list[float], optional) — Weights for each reward function. Must match the number of reward functions. If None, all rewards are weighted equally with weight 1.0.
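
When several reward functions are passed, the scalar reward r is their weighted sum. A minimal sketch, assuming each function is called with prompts and completions as keyword arguments plus the extra dataset columns; combined_reward and reward_short are hypothetical names for illustration, and reward_correct is the function from the usage example.

```python
def combined_reward(reward_funcs, prompts, completions, reward_weights=None, **columns):
    """Weighted sum of several reward functions: one scalar reward per completion."""
    if reward_weights is None:
        reward_weights = [1.0] * len(reward_funcs)  # default: every function weighted 1.0
    if len(reward_weights) != len(reward_funcs):
        raise ValueError("reward_weights must have one entry per reward function")
    per_func = [f(prompts=prompts, completions=completions, **columns) for f in reward_funcs]
    return [sum(w * scores[j] for w, scores in zip(reward_weights, per_func)) for j in range(len(completions))]


def reward_correct(completions, ground_truth, **kwargs):
    return [float(completion.strip() == truth) for completion, truth in zip(completions, ground_truth)]


def reward_short(completions, **kwargs):  # hypothetical second reward: 1.0 if at most 3 characters
    return [float(len(completion) <= 3) for completion in completions]


batch = dict(prompts=["1+1=", "2+2="], completions=["2", " 5 "], ground_truth=["2", "4"])
print(combined_reward([reward_correct, reward_short], **batch))  # [2.0, 1.0]
print(combined_reward([reward_correct, reward_short], reward_weights=[1.0, 0.0], **batch))  # [1.0, 0.0]
```

The first call shows why the weights matter for A*-PO: with the default weights, two binary rewards can sum to 2.0, so the combined reward is no longer in {0, 1} as the method assumes.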

Configuration class for the A2POTrainer.

This class includes only the parameters that are specific to A2PO training. For a full list of training arguments, please refer to the TrainingArguments documentation. Note that default values in this class may differ from those in TrainingArguments.
