1 change: 1 addition & 0 deletions .github/workflows/docker/docker-compose.yaml
@@ -8,6 +8,7 @@ services:
- RAY_ADDRESS=auto
- TRINITY_CHECKPOINT_ROOT_DIR=/mnt/checkpoints
- TRINITY_TASKSET_PATH=/mnt/data
- TRINITY_EVAL_TASKSET_PATH=/mnt/data
- TRINITY_SFT_DATASET_PATH=/mnt/data
- TRINITY_MODEL_PATH=/mnt/models/Qwen3-0.6B
- TRINITY_API_MODEL_PATH=/mnt/models/Qwen3-1.7B
221 changes: 221 additions & 0 deletions examples/rec_gsm8k/README.md
@@ -0,0 +1,221 @@
# Example: REC on GSM8k dataset

This example shows how to run REC on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k).

For more details, see the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).

The config file is located in [`gsm8k.yaml`](gsm8k.yaml).

# Group-relative REINFORCE Families
This folder provides **example configurations** for running different group-relative REINFORCE families within Trinity-RFT.

It includes three major families:

- **REC family** (clipping + importance sampling)
- **REP family** (regularization-based variants)
- **RED family** (data-distribution shaping strategies)

We also provide baseline implementations such as **Vanilla REINFORCE** and **GRPO**.

All algorithms are instantiated through modular YAML configs for easy reproduction and extension.

# Summary Table 📝

| Family | Variants | Key Idea |
| ------------- | ----------------------------------------------- | ----------------------------------- |
| **Baselines** | REINFORCE, GRPO | Standard references |
| **REC** | OneSide-NoIS, OneSide-IS, TwoSide-IS, Ring-NoIS | Clipping + importance sampling |
| **REP** | AsymRE, OPMD | Regularization |
| **RED** | Drop, Weight | Data-distribution shaping |



# Instantiations

## Baselines

### REINFORCE
Vanilla REINFORCE with the group mean reward as the baseline.

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "none"  # no clipping
    weight: "none"  # uniform weighting for samples
    temp: 1.0
    regularizer: "none"  # no regularizer
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: false
```
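
The group-mean baseline and the `std_normalize` flag can be illustrated with a few lines of Python. This is a minimal sketch, not the Trinity-RFT implementation: the helper name `group_advantages`, the `eps` guard, and the use of the population standard deviation are assumptions made for illustration.

```python
from statistics import mean, pstdev


def group_advantages(rewards, std_normalize=False, eps=1e-6):
    """Advantages for one group of rollouts sampled from the same prompt.

    Hypothetical helper: subtracts the group mean reward from each reward
    and, if std_normalize is set, divides by the group standard deviation.
    """
    baseline = mean(rewards)
    centered = [r - baseline for r in rewards]
    if not std_normalize:
        return centered
    # eps is an assumption; it avoids division by zero when all rewards match.
    scale = pstdev(rewards) + eps
    return [c / scale for c in centered]


print(group_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]
```

With `std_normalize=False`, as in the REINFORCE config, the advantages keep the scale of the raw rewards.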

### GRPO
GRPO with the KL regularizer disabled (`kl_coef: 0.0`). KL regularization can be enabled via `kl_loss_fn` and `kl_loss_fn_args`, for example by setting a nonzero `kl_coef`.

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "one-side"
    weight: "importance_sampling"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: true
  kl_loss_fn: 'k2'
  kl_loss_fn_args:
    kl_coef: 0.0
```
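
To make the roles of `epsilon_low`, `epsilon_high`, and importance-sampling weighting concrete, the sketch below evaluates the standard PPO-style clipped surrogate that GRPO is usually described with, for a single token. Treating `clip_mode: "one-side"` with `weight: "importance_sampling"` as this objective is an assumption, and `clipped_surrogate` is a hypothetical name rather than a Trinity-RFT function.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage,
                      epsilon_low=0.2, epsilon_high=0.2):
    """Per-token PPO-style objective (to be maximized).

    ratio = pi_new / pi_old is the importance-sampling weight. It is clipped
    to [1 - epsilon_low, 1 + epsilon_high], and the smaller of the clipped
    and unclipped terms is kept.
    """
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high)
    return min(ratio * advantage, clipped * advantage)


# Positive advantage: the gain is capped once the ratio exceeds 1 + epsilon_high.
print(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0))
# Negative advantage: the penalty is kept in full when the ratio is large.
print(clipped_surrogate(math.log(1.5), 0.0, advantage=-1.0))
```

Because the minimum is taken, clipping only ever removes the incentive to move the ratio further in the direction the advantage already favors.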

## REC family
Variants of clipping and importance-sampling strategies.
- REC-OneSide-NoIS
- REC-OneSide-IS
- REC-TwoSide-IS
- REC-Ring-NoIS
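
The four variants differ only in a handful of `policy_loss_fn_args` fields. As a rough illustration, the sketch below builds the `algorithm` section for each variant from the values listed in this README. The helper `rec_algorithm_config` and the dictionary names are hypothetical and not part of Trinity-RFT.

```python
# Settings shared by all REC variants in this README.
BASE_POLICY_LOSS_ARGS = {
    "epsilon_low": 0.2,
    "epsilon_high": 0.2,
    "temp": 1.0,
    "regularizer": "none",
    "regularizer_coef": 0.0,
}

# Per-variant overrides, copied from the configs in this README.
REC_VARIANTS = {
    "REC-OneSide-NoIS": {"clip_mode": "one-side", "weight": "none"},
    "REC-OneSide-IS": {"clip_mode": "one-side", "weight": "importance_sampling"},
    "REC-TwoSide-IS": {"clip_mode": "two-side", "weight": "importance_sampling"},
    "REC-Ring-NoIS": {
        "clip_mode": "ring",
        "weight": "none",
        "epsilon_low_prime": 0.6,
        "epsilon_high_prime": 2.0,
    },
}


def rec_algorithm_config(variant):
    """Build the `algorithm` section for one REC variant (hypothetical helper)."""
    args = {**BASE_POLICY_LOSS_ARGS, **REC_VARIANTS[variant]}
    return {
        "algorithm_type": "rec",
        "policy_loss_fn_args": args,
        "advantage_fn_args": {"std_normalize": False},
    }


print(rec_algorithm_config("REC-TwoSide-IS")["policy_loss_fn_args"]["clip_mode"])  # two-side
```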

### REC-OneSide-NoIS

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "one-side"
    weight: "none"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: false
```

### REC-OneSide-IS

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "one-side"
    weight: "importance_sampling"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: false
```

### REC-TwoSide-IS

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "two-side"
    weight: "importance_sampling"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: false
```

### REC-Ring-NoIS

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    epsilon_low_prime: 0.6
    epsilon_high_prime: 2.0
    clip_mode: "ring"
    weight: "none"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: false
```

## REP family

Regularization-based algorithms.
- AsymRE (forward KL regularization)
- Kimi’s OPMD (k2 regularizer)

### AsymRE

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    clip_mode: "none"
    weight: "none"
    temp: 1.0
    regularizer: "forward-kl"
    regularizer_coef: 0.1
  advantage_fn_args:
    std_normalize: false
```


### Kimi's OPMD

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    clip_mode: "none"
    weight: "none"
    regularizer: "k2"
    regularizer_coef: 0.1
  advantage_fn_args:
    std_normalize: false
```

## RED family
Data-distribution shaping variants.
- RED-Drop (drops excess negative examples to balance positive and negative examples)
- RED-Weight (advantage-weighting strategy)

### RED-Drop

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    clip_mode: "none"
    weight: "none"
    regularizer: "none"
  advantage_fn_args:
    std_normalize: false
    drop: "balance"
```
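
One plausible reading of `drop: "balance"` is sketched below: negative-advantage samples are dropped at random until they no longer outnumber the positive ones. The README only states that extra negatives are dropped to balance the two classes, so the random selection, the handling of zero-advantage samples, and the helper name `drop_to_balance` are all assumptions.

```python
import random


def drop_to_balance(samples, seed=0):
    """Drop negative-advantage samples until they do not outnumber positives.

    `samples` is a list of (response, advantage) pairs. Hypothetical helper:
    which negatives are dropped (here, a random subset) is an assumption.
    Note that if there are no positives, every negative is dropped.
    """
    rng = random.Random(seed)
    positives = [s for s in samples if s[1] > 0]
    negatives = [s for s in samples if s[1] < 0]
    neutral = [s for s in samples if s[1] == 0]  # kept untouched (assumption)
    if len(negatives) > len(positives):
        negatives = rng.sample(negatives, len(positives))
    return positives + neutral + negatives


group = [("a", 0.75), ("b", -0.25), ("c", -0.25), ("d", -0.25)]
print(len(drop_to_balance(group)))  # 2
```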


### RED-Weight

```
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    clip_mode: "none"
    weight: "advantage"
    regularizer: "none"
    temp: 1.0
  advantage_fn_args:
    std_normalize: false
```
85 changes: 85 additions & 0 deletions examples/rec_gsm8k/gsm8k.yaml
@@ -0,0 +1,85 @@
# Configuration file for the REC GSM8k project.
project: "Trinity-RFT-GSM8K"
name: rec_gsm8k
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
mode: both
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}
  max_response_tokens: 1024
  max_model_len: 1280
algorithm:
  algorithm_type: rec
  policy_loss_fn_args:
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "none"
    weight: "none"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  advantage_fn_args:
    std_normalize: false
  repeat_times: 8
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_steps: 100
  batch_size: 96
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH}
      split: train
      format:
        prompt_key: question
        response_key: answer
      rollout_args:
        temperature: 1.0
    eval_tasksets:
    - name: gsm8k-eval
      storage_type: file
      path: ${oc.env:TRINITY_EVAL_TASKSET_PATH}
      split: test
      format:
        prompt_key: question
        response_key: answer
    default_workflow_type: math_workflow
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
explorer:
  eval_interval: 20
  runner_num: 64
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: nccl
  sync_interval: 20
  sync_timeout: 1200
  sync_offset: 0
trainer:
  trainer_type: verl
  save_interval: 100
  trainer_config:
    actor_rollout_ref:
      model:
        use_remove_padding: true
      actor:
        use_dynamic_bsz: true
        ppo_max_token_len_per_gpu: 16384
        ulysses_sequence_parallel_size: 1
        optim:
          lr: 1e-6
      ref:
        log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz}
        log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
        ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size