80 changes: 68 additions & 12 deletions docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -1,17 +1,17 @@
# Asynchronous RFT

This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset.
This example demonstrates how to run RFT in fully asynchronous mode using the GRPO algorithm, Qwen2.5-1.5B-Instruct model, and GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
Trinity-RFT supports Asynchronous RFT by running the trainer and explorer in separate processes.

For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
The main difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`.
For this purpose, we provide two main configuration files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
The primary difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`.
The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.

Suppose we have a node of 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer.
Some important setups of `explorer.yaml` are listed in the following:
Assuming we have a node with 8 GPUs, we allocate 4 GPUs for the trainer and 4 GPUs for the explorer. Key configurations in `explorer.yaml` are as follows:

```yaml
# explorer.yaml
project: <project_name>
name: <experiment_name>
mode: explore
@@ -26,7 +26,7 @@ cluster:
gpu_per_node: 4
buffer:
total_epochs: 1
batch_size: 96
batch_size: 64
explorer_input:
taskset:
name: gsm8k
@@ -45,7 +45,6 @@ buffer:
storage_type: queue
path: 'sqlite:///gsm8k.db'
explorer:
eval_interval: 10
runner_num: 32
rollout_model:
engine_type: vllm_async
@@ -57,9 +56,10 @@ trainer:
trainer_config_path: examples/async_gsm8k/verl_config.yaml
```

Some important setups of `trainer.yaml` are listed in the following:
Key configurations in `trainer.yaml` are as follows:

```yaml
# trainer.yaml
project: <project_name>
name: <experiment_name>
mode: train
@@ -74,7 +74,7 @@ cluster:
gpu_per_node: 4
buffer:
total_epochs: 1
batch_size: 96
batch_size: 64
explorer_input:
taskset:
name: gsm8k
@@ -98,8 +98,7 @@ trainer:
trainer_config_path: examples/async_gsm8k/verl_config.yaml
```


You may run this example with the following command:
You can run this example with the following command:

```bash
bash examples/async_gsm8k/run.sh
@@ -110,3 +109,60 @@ The following plot shows the learning curve of GRPO in the asynchronous mode.
> We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in the asynchronous mode.

![async](../../assets/async-curve.png)


Trinity-RFT also supports dynamic scaling in asynchronous mode. Continuing the previous example, if an additional machine with 8 GPUs joins the Ray cluster during training, you can launch a new explorer with the configuration file `explorer_new.yaml` shown below.

```yaml
# explorer_new.yaml
project: <project_name>
name: <experiment_name>
mode: explore
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
cluster: # important
node_num: 1
gpu_per_node: 8
explorer:
name: 'explorer_new' # important
runner_num: 64
rollout_model:
engine_type: vllm_async
engine_num: 8
buffer:
total_epochs: 1
batch_size: 64
explorer_input:
taskset: # important
name: gsm8k
storage_type: file
path: /PATH/TO/DATASET/
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
path: 'sqlite:///gsm8k.db'
synchronizer:
sync_method: 'checkpoint'
sync_interval: 10
# other configs are the same as explorer.yaml
```

The differences between `explorer_new.yaml` and `explorer.yaml` include:

- `cluster.node_num/gpu_per_node`: Specify the cluster configuration for the newly added explorer.
- `explorer.name`: The newly launched explorer must use a name other than "explorer", which is the default name already taken by the existing explorer.
- `explorer.rollout_model.engine_num/tensor_parallel_size`: Set the number of engines and the tensor parallel size so that the GPUs on the new node are fully utilized.
- `buffer.explorer_input.taskset`: Provide another task dataset as input for the new explorer.

All other parameters remain the same as in `explorer.yaml`.
47 changes: 43 additions & 4 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -68,6 +68,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT
- `explore`: Only launches the explorer.
- `bench`: Used for benchmarking.
- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.
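For instance, a minimal sketch (the project and experiment names are placeholders) with the default namespace written out explicitly:

```yaml
project: async_gsm8k                          # placeholder
name: qwen2.5-1.5b-grpo                       # placeholder
# Optional; if omitted, defaults to "<project>/<name>",
# i.e. "async_gsm8k/qwen2.5-1.5b-grpo" in this sketch.
ray_namespace: async_gsm8k/qwen2.5-1.5b-grpo
```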

---

@@ -166,6 +167,9 @@ buffer:
eval_tasksets:
...

explorer_output:
...

trainer_input:
experience_buffer:
...
@@ -219,15 +223,15 @@ buffer:

The configuration for each task dataset is defined as follows:

- `name`: Name of the dataset. Name must be unique.
- `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique.
- `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`.
- `file`: The dataset is stored in `jsonl`/`parquet` files. The data files must be organized according to the Hugging Face standard. *We recommend this storage type for most cases.*
- `queue`: The dataset is stored in a simple FIFO queue. *Do not use this storage type for the task dataset unless you know what you are doing.*
- `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in future versions.*
- `path`: The path to the task dataset.
- For `file` storage type, the path is the path to the directory that contains the task dataset files.
- For `file` storage type, the path points to the directory that contains the task dataset files.
- For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here.
- For `sql` storage type, the path is the path to the sqlite database file.
- For `sql` storage type, the path points to the sqlite database file.
- `subset_name`: The subset name of the task dataset. Default is `None`.
- `split`: The split of the task dataset. Default is `train`.
- `format`: Defines keys for prompts and responses in the dataset.
@@ -240,6 +244,34 @@ The configuration for each task dataset is defined as follows:
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
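
Putting these fields together, a `file`-based task dataset could be configured as in the sketch below, mirroring the explorer examples earlier in this page; the dataset name, path, and keys are illustrative placeholders rather than required values.

```yaml
buffer:
  explorer_input:
    taskset:
      name: gsm8k                    # must be unique; also used as the Ray actor's name
      storage_type: file             # recommended for most cases
      path: /PATH/TO/DATASET/        # directory containing the jsonl/parquet files
      subset_name: null              # optional subset, defaults to None
      split: train                   # defaults to train
      format:
        prompt_key: 'question'       # field holding the prompt
        response_key: 'answer'       # field holding the reference answer
      rollout_args:
        temperature: 1.0             # sampling temperature used during rollout
      default_workflow_type: 'math_workflow'
```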


### Explorer Output

In [`explore` mode](#global-configuration), there is no trainer, so users configure the experience buffer via `buffer.explorer_output` instead of `buffer.trainer_input`, which is introduced in the next section.

> For `both` and `train` modes, users should use `buffer.trainer_input` instead of `buffer.explorer_output`.

```yaml
buffer:
...
explorer_output:
name: countdown_buffer
storage_type: queue
path: sqlite:///countdown_buffer.db
wrap_in_ray: True
```

- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
- `storage_type`: The storage type for the experience buffer.
- `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
- `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes.
- `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
- `path`: The path to the experience buffer.
- For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
- For `file` storage type, the path points to the directory containing the dataset files.
- For `sql` storage type, the path points to the SQLite database file.
- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. This only takes effect when `storage_type` is `sql` or `file`; the `queue` storage type always uses a Ray actor.


### Trainer Input

Defines the experience buffer and optional SFT warm-up dataset.
@@ -264,7 +296,7 @@ buffer:
sft_warmup_steps: 0
```

- `experience_buffer`: Experience replay buffer used by the trainer.
- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`.
- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
- `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
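
As a sketch of how these pieces fit together, the configuration below enables a short SFT warm-up before RL. The `experience_buffer` fields follow the examples earlier in this page; the `sft_warmup_dataset` fields assume it uses the same storage schema as other datasets, and all names and paths are placeholders.

```yaml
buffer:
  ...
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer             # must be unique; also used as the Ray actor's name
      storage_type: queue
      path: 'sqlite:///gsm8k.db'     # optional backup of the queued experience data
    sft_warmup_dataset:              # assumed to follow the same storage schema as other datasets
      name: gsm8k_sft
      storage_type: file
      path: /PATH/TO/SFT/DATASET/
    sft_warmup_steps: 20             # run 20 SFT steps before RL training begins
```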

@@ -276,6 +308,7 @@ Controls the rollout models and workflow execution.

```yaml
explorer:
name: explorer
runner_num: 32
rollout_model:
engine_type: vllm_async
@@ -286,11 +319,13 @@ explorer:
tensor_parallel_size: 1
```

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_num`: Number of parallel workflow runners.
- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
- `rollout_model.engine_num`: Number of inference engines.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
- `auxiliary_models`: Additional models used for custom workflows.
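
The `auxiliary_models` schema is not shown in this excerpt, so the sketch below is hypothetical: it assumes each entry accepts a `model_path` plus the same engine fields as `rollout_model`. Consult the full configuration reference for the exact schema.

```yaml
explorer:
  ...
  auxiliary_models:
    # Hypothetical entry; field names are assumptions based on rollout_model above.
    - model_path: /PATH/TO/AUXILIARY/MODEL/
      engine_type: vllm_async
      engine_num: 1
      tensor_parallel_size: 1
```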

---

## Synchronizer Configuration
@@ -301,13 +336,15 @@ Controls how model weights are synchronized between trainer and explorer.
synchronizer:
sync_method: 'nccl'
sync_interval: 10
sync_offset: 0
sync_timeout: 1200
```

- `sync_method`: Method of synchronization. Options:
- `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode.
- `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode.
- `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer.
- `sync_offset`: Offset (in steps) of model weight synchronization between trainer and explorer. The explorer can run `sync_offset` steps before the trainer starts training.
- `sync_timeout`: Timeout duration for synchronization.
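
To make the cadence concrete, here is a small worked sketch. As noted in the asynchronous example, the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks, so the illustrative values below would synchronize every 640 tasks.

```yaml
synchronizer:
  sync_method: 'checkpoint'   # used when explorer and trainer run as separate processes
  sync_interval: 10           # synchronize every 10 steps
  sync_offset: 2              # the explorer may run 2 steps before the trainer starts
  sync_timeout: 1200
# With buffer.batch_size: 64, model weights are synchronized every
# 10 * 64 = 640 tasks.
```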

---
@@ -318,12 +355,14 @@ Specifies the backend and behavior of the trainer.

```yaml
trainer:
name: trainer
trainer_type: 'verl'
save_interval: 100
trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
trainer_config: null
```

- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `trainer_config_path`: The path to the trainer configuration file.
5 changes: 2 additions & 3 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -2,9 +2,8 @@

This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines.

Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**.
We decouple the RL pipeline into three modules to make it easier to customize and extend.
Below is a table summarizing the modules and components that developers with different tragets need to focus on.
In Trinity-RFT, we decompose the RL pipeline into three main modules (**Explorer**, **Trainer** and **Buffer**) to facilitate customization and extension.
Below is a table summarizing the modules and components that developers with different targets need to focus on.

| Development Target | Core Module | Key Component |
|--------------------|-------------|---------------|
2 changes: 1 addition & 1 deletion tests/buffer/file_test.py
@@ -47,7 +47,7 @@ def test_file_buffer(self):
# test writer
writer = JSONWriter(meta, None)
writer.write(data)
writer.finish()
writer.release()

# test reader
meta.path = self.temp_output_path
3 changes: 2 additions & 1 deletion tests/buffer/queue_test.py
@@ -30,6 +30,7 @@ def test_queue_buffer(self):
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
self.assertEqual(writer.acquire(), 1)
exps = [
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
@@ -59,7 +60,7 @@ def test_queue_buffer(self):
)
exps = reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
writer.finish()
self.assertEqual(writer.release(), 0)
self.assertRaises(StopIteration, reader.read)
with open(BUFFER_FILE_PATH, "r") as f:
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
3 changes: 3 additions & 0 deletions tests/buffer/sql_test.py
@@ -42,6 +42,7 @@ def test_create_sql_buffer(self) -> None:
)
for i in range(1, put_batch_size + 1)
]
self.assertEqual(sql_writer.acquire(), 1)
for _ in range(total_num // put_batch_size):
sql_writer.write(exps)
for _ in range(total_num // read_batch_size):
@@ -65,3 +66,5 @@
self.assertEqual(len(exps), put_batch_size * 2)
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
self.assertEqual(sql_writer.release(), 0)
self.assertRaises(StopIteration, sql_reader.read)
3 changes: 0 additions & 3 deletions tests/explorer/runner_pool_test.py
@@ -250,9 +250,6 @@ def test_runner_pool_with_auxiliary_models(self):
)

# `auxiliary_models`
st = time.time()
status = pool.get_next_unorder()
et = time.time()
self.assertTrue(et - st < 1)
self.assertEqual(len(status), 1)
self.assertTrue(status[0].ok)
2 changes: 1 addition & 1 deletion tests/tools.py
@@ -42,7 +42,7 @@ def get_checkpoint_path() -> str:
def get_unittest_dataset_config(
dataset_name: str = "countdown", split: str = "train"
) -> StorageConfig:
"""Countdown sample dataset for 8 steps"""
"""Countdown dataset with 16 samples."""
if dataset_name == "countdown" or dataset_name == "copy_countdown":
return StorageConfig(
name=dataset_name,