2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/faq.md
@@ -95,7 +95,7 @@ ray start --head
**A:** The following parameters may be helpful:

- For trainer, adjust `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu` when `actor_rollout_ref.actor.use_dynamic_bsz=false`; adjust `actor_rollout_ref.actor.ppo_max_token_len_per_gpu` and `actor_rollout_ref.actor.ulysses_sequence_parallel_size` when `actor_rollout_ref.actor.use_dynamic_bsz=true`. Setting `actor_rollout_ref.actor.entropy_from_logits_with_chunking=true` may also help.
- For explorer, adjust `explorer.rollout_model.tensor_parallel_size`,
- For explorer, adjust `explorer.rollout_model.tensor_parallel_size`.


## Part 3: Debugging Methods [Coming Soon]
2 changes: 0 additions & 2 deletions examples/grpo_gsm8k_ruler/README.md
@@ -4,8 +4,6 @@ This example shows a toy implementation of ART's [RULER](https://art.openpipe.ai

RULER (Relative Universal LLM-Elicited Rewards) is a general-purpose reward function that uses an LLM-as-judge to rank the rollouts for a given task.

https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py


## Configurations and Metrics

55 changes: 55 additions & 0 deletions examples/grpo_gsm8k_trainable_ruler/README.md
@@ -0,0 +1,55 @@
# Policy Model as Its Own Reward Model


Ref: ART's RULER; Kimi-k2.

This example shows an implementation of training a policy model as its own reward model with GRPO, inspired by ART's [RULER](https://art.openpipe.ai/fundamentals/ruler) and KIMI's [K2](https://moonshotai.github.io/Kimi-K2/).

We simulate a scenario where only a fraction (`PROBABILITY_GROUND_TRUTH_AVAILABLE = 0.2`) of tasks have ground-truth answers. We optimize two objectives jointly: one for response generation, the other for RULER-reward generation.
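As a rough illustration of this setup (hypothetical code, not the actual workflow implementation), each task keeps its ground-truth answer only with probability 0.2; the remaining tasks must be scored by the policy model acting as its own RULER judge:

```python
import random

PROBABILITY_GROUND_TRUTH_AVAILABLE = 0.2  # fraction of tasks with a usable gold answer


def has_ground_truth(rng: random.Random) -> bool:
    """Decide whether this task is treated as having a ground-truth answer."""
    return rng.random() < PROBABILITY_GROUND_TRUTH_AVAILABLE


rng = random.Random(42)
n_gold = sum(has_ground_truth(rng) for _ in range(1000))
print(f"{n_gold / 1000:.0%} of tasks scored by rules; the rest scored by the policy-as-judge.")
```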


## Configurations and Metrics

The config file is located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml).

Some key configs in this example are:

* `default_workflow_type`: set to `math_trainable_ruler_workflow`
* `std_threshold` for GRPO advantage: set to a small value to filter out groups of experiences whose rewards are all identical (e.g., when RULER fails to return valid scores, they are all set to zero)
* `sync_style`: set to `dynamic_by_explorer`, since experiences may be filtered out
* `train_batch_size`: set to 960; note that one explore step can generate more than 96 * 8 = 768 experiences
* `lr`: set to a small value (2e-6) for stability, as rewards can be noisy



Some important metrics to pay attention to are:

* `reward`: reward calculated by rules or by RULER
* `gold_reward`: sum of `accuracy_reward` and `format_reward`, computed by rules against the ground truth
* `judge_success`: whether RULER successfully returns valid scores (a coarse estimate, as it mixes the two types of experiences)
* `reward_for_judger`: reward for the LLM acting as a RULER reward model, computed from the mean absolute error (MAE) between its scores and the gold scores (see the sketch after this list)
* `eval_accuracy`: accuracy on the evaluation set (ultimate metric for success of RL)
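
For concreteness, here is a minimal sketch of how such a judger reward could be computed; the exact transform used by `math_trainable_ruler_workflow` may differ, so treat the `1 - MAE` formulation and the function name as illustrative assumptions.

```python
from typing import List


def reward_for_judger(judge_scores: List[float], gold_scores: List[float]) -> float:
    """Illustrative only: reward the judge by closeness of its scores to the gold scores.

    Scores are assumed to lie in [0, 1]; a perfect judge gets 1.0 and the reward
    drops linearly as the mean absolute error (MAE) grows.
    """
    assert len(judge_scores) == len(gold_scores)
    mae = sum(abs(j - g) for j, g in zip(judge_scores, gold_scores)) / len(gold_scores)
    return 1.0 - mae


# Judge scores vs. rule-based gold scores for a group of 4 rollouts
print(reward_for_judger([0.9, 0.2, 0.7, 0.1], [1.0, 0.0, 1.0, 0.0]))  # 0.825
```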


## Results

We show the results below:


![reward](../../docs/sphinx_doc/assets/gsm8k_trainable_ruler_reward.png)

![gold_reward](../../docs/sphinx_doc/assets/gsm8k_trainable_ruler_gold_reward.png)

![judge_success](../../docs/sphinx_doc/assets/gsm8k_trainable_ruler_judge_success.png)

![reward_for_judger](../../docs/sphinx_doc/assets/gsm8k_trainable_ruler_reward_for_judger.png)

![eval_accuracy](../../docs/sphinx_doc/assets/gsm8k_trainable_ruler_eval_accuracy.png)


You may compare the above results with those of [the RULER example](../../examples/grpo_gsm8k_ruler/README.md), which uses Qwen2.5-32B-Instruct as the LLM judge (via `auxiliary_models`).


## Potential improvements

As this is a toy example, further improvements are possible, such as automatically balancing the number of samples for the two objectives, or their loss weights. We also plan to test this approach in broader scenarios, e.g., cross-domain transfer of the model's critic capability.
78 changes: 78 additions & 0 deletions examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
@@ -0,0 +1,78 @@
project: "Trinity-RFT-gsm8k-trainable-ruler"
name: "qwen2.5-1.5B-gsm8k-trainable-ruler"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
advantage_fn_args:
std_threshold: 0.0001 # effectively zero
repeat_times: 8
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16000 # slightly smaller than ppo_max_token_len_per_gpu (16384)
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
train_batch_size: 960
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_trainable_ruler_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
explorer:
eval_interval: 10
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_style: dynamic_by_explorer
sync_method: 'nccl'
sync_interval: 5
sync_timeout: 3600
trainer:
save_interval: 100
trainer_config:
actor_rollout_ref:
model:
use_remove_padding: true
actor:
use_dynamic_bsz: true
ppo_max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
optim:
lr: 2e-6
ref:
log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
4 changes: 2 additions & 2 deletions tests/common/experience_test.py
@@ -33,8 +33,8 @@ def test_eid_properties(self):
        eid = EID()
        eid2 = EID()
        self.assertIsInstance(eid.suffix, str)
        self.assertEqual(eid.batch, 0)
        self.assertEqual(eid.task, 0)
        self.assertEqual(eid.batch, "")
        self.assertEqual(eid.task, "")
        self.assertEqual(eid.run, 0)
        self.assertEqual(eid.step, 0)
        self.assertNotEqual(eid.uid, eid2.uid)
6 changes: 3 additions & 3 deletions trinity/common/experience.py
@@ -5,7 +5,7 @@
import pickle
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from datasets import Dataset
@@ -22,10 +22,10 @@ class EID:
    # TODO: do we need to add project/name here to make it unique across different projects?
    # Batch number, e.g., the explorer step num
    # Automatically set by the workflow runner
    batch: int = 0
    batch: Union[int, str] = ""
    # Task number, e.g., the task sequence in the batch, the first task in the batch has task=0
    # Automatically set by the workflow runner
    task: int = 0
    task: Union[int, str] = ""
    # Run id, e.g., the first run in the task has run=0
    # User should set this field in custom workflows when creating experiences
    run: int = 0
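
A brief usage sketch of the widened `EID` fields; note that treating string batch/task identifiers as the motivation for this change is an inference from the diff, not something stated in the PR:

```python
from trinity.common.experience import EID

# Integer indices, as set automatically by the workflow runner
eid_from_runner = EID(batch=3, task=12, run=0)

# String identifiers are now also accepted; unset fields default to ""
eid_custom = EID(batch="eval-gsm8k", task="q-0017")
print(eid_custom.batch, eid_custom.task, eid_custom.run)  # eval-gsm8k q-0017 0
```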
2 changes: 2 additions & 0 deletions trinity/common/workflows/__init__.py
@@ -12,6 +12,7 @@
from .eval_workflow import MathEvalWorkflow
from .math_rm_workflow import MathRMWorkflow
from .math_ruler_workflow import MathRULERWorkflow
from .math_trainable_ruler_workflow import MathTrainableRULERWorkflow
from .simple_mm_workflow import SimpleMMWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow

@@ -34,5 +34,6 @@
"AgentScopeReactV2MathWorkflow",
"EmailSearchWorkflow",
"MathRULERWorkflow",
"MathTrainableRULERWorkflow",
"SimpleMMWorkflow",
]
7 changes: 6 additions & 1 deletion trinity/common/workflows/math_ruler_workflow.py
@@ -120,7 +120,7 @@ def get_ruler_scores(
Please assign a score within the range [0, 1] for each of them, reflecting how well they solve the question.
You may compare them against each other and think step by step before returning your final scores, but keep your reasoning process brief and concise when possible.

Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses + 1}]
Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses}]
"""

# Step 2: invoke judger LLM
@@ -145,6 +145,11 @@
        try:
            scores = ast.literal_eval(lst_as_str)
            scores = [max(0.0, min(1.0, score)) for score in scores]  # clip to range [0, 1]
            if len(scores) != num_responses:
                self.logger.warning(
                    "The length of list in judger response does not match num_responses."
                )
                return False, [0.0 for _ in range(num_responses)]
            return True, scores
        except Exception:
            self.logger.warning(
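
For reference, a self-contained sketch of the same parse-and-validate step (simplified: the real method lives inside `get_ruler_scores` and logs warnings via `self.logger`):

```python
import ast
from typing import List, Tuple


def parse_judge_scores(lst_as_str: str, num_responses: int) -> Tuple[bool, List[float]]:
    """Parse the judger's score list, clip to [0, 1], and validate its length."""
    try:
        scores = ast.literal_eval(lst_as_str)
        scores = [max(0.0, min(1.0, float(s))) for s in scores]
        if len(scores) != num_responses:
            return False, [0.0] * num_responses  # length mismatch -> fall back to zeros
        return True, scores
    except Exception:
        return False, [0.0] * num_responses  # unparsable output -> fall back to zeros


print(parse_judge_scores("[0.9, 0.3, 1.4]", 3))  # (True, [0.9, 0.3, 1.0])
print(parse_judge_scores("[0.9, 0.3]", 3))       # (False, [0.0, 0.0, 0.0])
```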