diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md index 1f9a9c8665..70ca66e2b2 100644 --- a/docs/sphinx_doc/source/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md @@ -1,17 +1,17 @@ # Asynchronous RFT -This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset. +This example demonstrates how to run RFT in fully asynchronous mode using the GRPO algorithm, Qwen2.5-1.5B-Instruct model, and GSM8K dataset. -Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes. +Trinity-RFT supports Asynchronous RFT by running the trainer and explorer in separate processes. -For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml). -The main difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`. +For this purpose, we provide two main configuration files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml). +The primary difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`. The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks. -Suppose we have a node of 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer. -Some important setups of `explorer.yaml` are listed in the following: +Assuming we have a node with 8 GPUs, we allocate 4 GPUs for the trainer and 4 GPUs for the explorer. Key configurations in `explorer.yaml` are as follows: ```yaml +# explorer.yaml project: name: mode: explore @@ -26,7 +26,7 @@ cluster: gpu_per_node: 4 buffer: total_epochs: 1 - batch_size: 96 + batch_size: 64 explorer_input: taskset: name: gsm8k @@ -45,7 +45,6 @@ buffer: storage_type: queue path: 'sqlite:///gsm8k.db' explorer: - eval_interval: 10 runner_num: 32 rollout_model: engine_type: vllm_async @@ -57,9 +56,10 @@ trainer: trainer_config_path: examples/async_gsm8k/verl_config.yaml ``` -Some important setups of `trainer.yaml` are listed in the following: +Key configurations in `trainer.yaml` are as follows: ```yaml +# trainer.yaml project: name: mode: train @@ -74,7 +74,7 @@ cluster: gpu_per_node: 4 buffer: total_epochs: 1 - batch_size: 96 + batch_size: 64 explorer_input: taskset: name: gsm8k @@ -98,8 +98,7 @@ trainer: trainer_config_path: examples/async_gsm8k/verl_config.yaml ``` - -You may run this example with the following command: +You can run this example with the following command: ```bash bash examples/async_gsm8k/run.sh @@ -110,3 +109,60 @@ The following plot shows the learning curve of GRPO in the asynchronous mode. > We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in the asynchronous mode. ![async](../../assets/async-curve.png) + + +Trinity-RFT also supports dynamic scaling in asynchronous mode. 
Continuing with the previous example, if an additional machine with 8 GPUs joins the Ray cluster during training, you can launch a new explorer using the following configuration `explorer_new.yaml`. + +```yaml +# explorer_new.yaml +project: +name: +mode: explore +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ +cluster: # important + node_num: 1 + gpu_per_node: 8 +explorer: + name: 'explorer_new' # important + runner_num: 64 + rollout_model: + engine_type: vllm_async + engine_num: 8 +buffer: + total_epochs: 1 + batch_size: 64 + explorer_input: + taskset: # important + name: gsm8k + storage_type: file + path: /PATH/TO/DATASET/ + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' +synchronizer: + sync_method: 'checkpoint' + sync_interval: 10 +# other configs are the same as explorer.yaml +``` + +The differences between `explorer_new.yaml` and `explorer.yaml` include: + +- `cluster.node_num/gpu_per_node`: Specify the cluster configuration for the newly added explorer. +- `explorer.name`: The later-started explorer requires a different name than "explorer", which is the default name for the existing explorer. +- `explorer.rollout_model.engine_num/tensor_parallel_size`: Define the engine number and tensor parallel size to optimally utilize GPU resources. +- `buffer.explorer_input.taskset`: Provide another task dataset as input for the new explorer. + +All other parameters remain the same as in `explorer.yaml`. diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f6a6d8c780..88d925f786 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -68,6 +68,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT - `explore`: Only launches the explorer. - `bench`: Used for benchmarking. - `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `///`. +- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `/`. --- @@ -166,6 +167,9 @@ buffer: eval_tasksets: ... + explorer_output: + ... + trainer_input: experience_buffer: ... @@ -219,15 +223,15 @@ buffer: The configuration for each task dataset is defined as follows: -- `name`: Name of the dataset. Name must be unique. +- `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique. - `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`. - `file`: The dataset is stored in `jsonl`/`parquet` files. The data file organization is required to meet the huggingface standard. *We recommand using this storage type for most cases.* - `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task dataset unless you know what you are doing.* - `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in the future versions.* - `path`: The path to the task dataset. - - For `file` storage type, the path is the path to the directory that contains the task dataset files. 
+ - For `file` storage type, the path points to the directory that contains the task dataset files. - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here. - - For `sql` storage type, the path is the path to the sqlite database file. + - For `sql` storage type, the path points to the sqlite database file. - `subset_name`: The subset name of the task dataset. Default is `None`. - `split`: The split of the task dataset. Default is `train`. - `format`: Defines keys for prompts and responses in the dataset. @@ -240,6 +244,34 @@ The configuration for each task dataset is defined as follows: - `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters. +### Explorer Output + +In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section. + +> For `both` and `train` modes, users should use `buffer.trainer_input` instead of `buffer.explorer_output`. + +```yaml +buffer: + ... + explorer_output: + name: countdown_buffer + storage_type: queue + path: sqlite:///countdown_buffer.db + wrap_in_ray: True +``` + +- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique. +- `storage_type`: The storage type for the experience buffer. + - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases. + - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes. + - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode. +- `path`: The path to the experience buffer. + - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. + - For `file` storage type, the path points to the directory containing the dataset files. + - For `sql` storage type, the path points to the SQLite database file. +- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor. + + ### Trainer Input Defines the experience buffer and optional SFT warm-up dataset. @@ -264,7 +296,7 @@ buffer: sft_warmup_steps: 0 ``` -- `experience_buffer`: Experience replay buffer used by the trainer. +- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`. - `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup). - `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins. @@ -276,6 +308,7 @@ Controls the rollout models and workflow execution. ```yaml explorer: + name: explorer runner_num: 32 rollout_model: engine_type: vllm_async @@ -286,11 +319,13 @@ explorer: tensor_parallel_size: 1 ``` +- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. - `runner_num`: Number of parallel workflow runners. - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. 
- `auxiliary_models`: Additional models used for custom workflows. + --- ## Synchronizer Configuration @@ -301,6 +336,7 @@ Controls how model weights are synchronized between trainer and explorer. synchronizer: sync_method: 'nccl' sync_interval: 10 + sync_offset: 0 sync_timeout: 1200 ``` @@ -308,6 +344,7 @@ synchronizer: - `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode. - `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode. - `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer. +- `sync_offset`: Offset (in steps) of model weight synchronization between trainer and explorer. The explorer can run `sync_offset` steps before the trainer starts training. - `sync_timeout`: Timeout duration for synchronization. --- @@ -318,12 +355,14 @@ Specifies the backend and behavior of the trainer. ```yaml trainer: + name: trainer trainer_type: 'verl' save_interval: 100 trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' trainer_config: null ``` +- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique. - `trainer_type`: Trainer backend implementation. Currently only supports `verl`. - `save_interval`: Frequency (in steps) at which to save model checkpoints. - `trainer_config_path`: The path to the trainer configuration file. diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index e07e6bb3dc..fb75d084b1 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -2,9 +2,8 @@ This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines. -Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**. -We decouple the RL pipeline into three modules to make it easier to customize and extend. -Below is a table summarizing the modules and components that developers with different tragets need to focus on. +In Trinity-RFT, we decompose the RL pipeline into three main modules (**Explorer**, **Trainer** and **Buffer**) to facilitate customization and extension. +Below is a table summarizing the modules and components that developers with different targets need to focus on. 
| Development Target | Core Module | Key Component | |--------------------|-------------|---------------| diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index e53669a850..2882dd8e0f 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -47,7 +47,7 @@ def test_file_buffer(self): # test writer writer = JSONWriter(meta, None) writer.write(data) - writer.finish() + writer.release() # test reader meta.path = self.temp_output_path diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 03e96e4291..23271c6158 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -30,6 +30,7 @@ def test_queue_buffer(self): ) writer = QueueWriter(meta, config) reader = QueueReader(meta, config) + self.assertEqual(writer.acquire(), 1) exps = [ Experience( tokens=torch.tensor([float(j) for j in range(i + 1)]), @@ -59,7 +60,7 @@ def test_queue_buffer(self): ) exps = reader.read(batch_size=put_batch_size * 2) self.assertEqual(len(exps), put_batch_size * 2) - writer.finish() + self.assertEqual(writer.release(), 0) self.assertRaises(StopIteration, reader.read) with open(BUFFER_FILE_PATH, "r") as f: self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 56305be671..e40a91b4c7 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -42,6 +42,7 @@ def test_create_sql_buffer(self) -> None: ) for i in range(1, put_batch_size + 1) ] + self.assertEqual(sql_writer.acquire(), 1) for _ in range(total_num // put_batch_size): sql_writer.write(exps) for _ in range(total_num // read_batch_size): @@ -65,3 +66,5 @@ def test_create_sql_buffer(self) -> None: self.assertEqual(len(exps), put_batch_size * 2) db_wrapper = ray.get_actor("sql-test_buffer") self.assertIsNotNone(db_wrapper) + self.assertEqual(sql_writer.release(), 0) + self.assertRaises(StopIteration, sql_reader.read) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 52f961bda4..735255ecf2 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -250,9 +250,6 @@ def test_runner_pool_with_auxiliary_models(self): ) # `auxiliary_models` - st = time.time() status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st < 1) self.assertEqual(len(status), 1) self.assertTrue(status[0].ok) diff --git a/tests/tools.py b/tests/tools.py index 209b5eb1c2..0b4ffd5750 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -42,7 +42,7 @@ def get_checkpoint_path() -> str: def get_unittest_dataset_config( dataset_name: str = "countdown", split: str = "train" ) -> StorageConfig: - """Countdown sample dataset for 8 steps""" + """Countdown dataset with 16 samples.""" if dataset_name == "countdown" or dataset_name == "copy_countdown": return StorageConfig( name=dataset_name, diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 32d19e9190..250ea3eb40 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1,7 +1,11 @@ """Tests for trainer.""" +import multiprocessing import os import shutil +import time +import unittest from abc import abstractmethod +from copy import deepcopy from datetime import datetime import ray @@ -14,8 +18,11 @@ get_template_config, get_unittest_dataset_config, ) -from trinity.cli.launcher import bench, both, train -from trinity.common.constants import SyncMethod +from trinity.cli.launcher import bench, both, explore, train +from trinity.common.config import 
Config, StorageConfig +from trinity.common.constants import StorageType, SyncMethod +from trinity.common.models.utils import get_checkpoint_dir_with_step_num +from trinity.manager.manager import CacheManager class BaseTrainerCase(RayUnittestBase): @@ -149,7 +156,6 @@ def test_trainer(self): response_metrics = parser.metric_list("response_length") self.assertTrue(len(response_metrics) > 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) - ray.timeline(filename="timeline.json") ray.shutdown(_exiting_interpreter=True) # check checkpoint from trinity.common.models.utils import get_checkpoint_dir_with_step_num @@ -262,3 +268,128 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir) + + +def run_trainer(config: Config) -> None: + ray.init(namespace=config.ray_namespace) + train(config) + + +def run_explorer(config: Config) -> None: + ray.init(namespace=config.ray_namespace) + explore(config) + + +class TestFullyAsyncMode(unittest.TestCase): + def setUp(self): + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + + def test_fully_async_mode(self): + config = get_template_config() + config.project = "unittest" + config.name = f"fully_async_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.cluster.gpu_per_node = 2 + config.cluster.node_num = 1 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = SyncMethod.CHECKPOINT + config.synchronizer.sync_interval = 8 + config.monitor.monitor_type = "tensorboard" + trainer_config = deepcopy(config) + trainer_config.mode = "train" + trainer_config.check_and_update() + + explorer1_config = deepcopy(config) + explorer1_config.mode = "explore" + explorer1_config.explorer.name = "explorer1" + config.cluster.gpu_per_node = 1 + config.cluster.node_num = 1 + explorer1_config.explorer.rollout_model.engine_num = 1 + explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 + explorer1_config.explorer.runner_num = 4 + explorer1_config.buffer.explorer_output = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + explorer2_config = deepcopy(explorer1_config) + explorer1_config.check_and_update() + + trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) + trainer_process.start() + + ray.init(ignore_reinit_error=True) + while True: + try: + ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace) + break + except ValueError: + print("waiting for trainer to start.") + time.sleep(5) + + explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,)) + explorer_process_1.start() + + time.sleep(20) + explorer2_config.explorer.name = "explorer2" + explorer2_config.check_and_update() + explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,)) + explorer_process_2.start() + + explorer_process_1.join() + explorer_process_2.join() + + # wait for trainer process to finish. 
+ trainer_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser( + os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer") + ) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + parser = TensorBoardParser( + os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + # check the checkpoint + explorer1_cache = CacheManager(explorer1_config) + cache = explorer1_cache.load_explorer() + self.assertEqual(cache["latest_iteration"], 4) + explorer2_cache = CacheManager(explorer2_config) + cache = explorer2_cache.load_explorer() + self.assertEqual(cache["latest_iteration"], 4) + self.assertIsNotNone( + get_checkpoint_dir_with_step_num( + checkpoint_root_path=explorer1_config.checkpoint_job_dir, + trainer_type="verl", + step_num=8, + ) + ) + self.assertIsNotNone( + get_checkpoint_dir_with_step_num( + checkpoint_root_path=explorer2_config.checkpoint_job_dir, + trainer_type="verl", + step_num=8, + ) + ) + ray.shutdown() + + def tearDown(self): + checkpoint_path = get_checkpoint_path() + shutil.rmtree(os.path.join(checkpoint_path, "unittest")) diff --git a/trinity/buffer/buffer_writer.py b/trinity/buffer/buffer_writer.py index ac245f50b6..13079ffb76 100644 --- a/trinity/buffer/buffer_writer.py +++ b/trinity/buffer/buffer_writer.py @@ -11,5 +11,17 @@ def write(self, data: List) -> None: """Write to buffer.""" @abstractmethod - def finish(self) -> None: - """Finish writing.""" + def acquire(self) -> int: + """Acquire the buffer writer. + + Returns: + `int`: The reference count of the buffer after acquiring. + """ + + @abstractmethod + def release(self) -> int: + """Release the buffer writer. After release, the buffer writer can not be used again. + + Returns: + `int`: The reference count of the buffer after releasing. 
+ """ diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index a3db72ef90..9cdd99a592 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -3,6 +3,8 @@ from copy import deepcopy from typing import List +import ray + from trinity.buffer.writer.file_writer import JSONWriter from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig @@ -44,6 +46,19 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: st_config.storage_type = StorageType.FILE self.writer = JSONWriter(st_config, self.config) self.logger.warning(f"Save experiences in {st_config.path}.") + self.ref_count = 0 + + async def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + async def release(self) -> int: + """Release the queue.""" + self.ref_count -= 1 + if self.ref_count <= 0: + await self.queue.put(self.FINISH_MESSAGE) + self.writer.release() + return self.ref_count def length(self) -> int: """The length of the queue.""" @@ -55,10 +70,6 @@ async def put_batch(self, exp_list: List) -> None: if self.writer is not None: self.writer.write(exp_list) - async def finish(self) -> None: - """Stop the queue.""" - await self.queue.put(self.FINISH_MESSAGE) - async def get_batch(self, batch_size: int) -> List: """Get batch of experience.""" batch = [] @@ -70,3 +81,16 @@ async def get_batch(self, batch_size: int) -> List: if len(batch) >= batch_size: break return batch + + @classmethod + def get_actor(cls, storage_config: StorageConfig, config: BufferConfig): + """Get the queue actor.""" + return ( + ray.remote(cls) + .options( + name=f"queue-{storage_config.name}", + namespace=ray.get_runtime_context().namespace, + get_if_exists=True, + ) + .remote(storage_config, config) + ) diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 71e9102999..ba736b02bc 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -47,6 +47,8 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.batch_size = config.read_batch_size self.max_retry_times = config.max_retry_times self.max_retry_interval = config.max_retry_interval + self.ref_count = 0 + self.stopped = False @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): @@ -71,6 +73,9 @@ def write(self, data: list) -> None: def read( self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None ) -> List: + if self.stopped: + raise StopIteration() + if strategy is None: strategy = ReadStrategy.LFU @@ -114,6 +119,16 @@ def read( self.logger.info(f"first response_text = {exp_list[0].response_text}") return exp_list + def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + def release(self) -> int: + self.ref_count -= 1 + if self.ref_count <= 0: + self.stopped = True + return self.ref_count + class _Encoder(json.JSONEncoder): def default(self, o): @@ -147,6 +162,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: os.makedirs(path_dir, exist_ok=True) self.file = open(storage_config.path, "a", encoding="utf-8") self.encoder = _Encoder(ensure_ascii=False) + self.ref_count = 0 @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): @@ -174,5 +190,12 @@ def read(self) -> List: "read() is not implemented for FileWrapper, please use QUEUE instead" ) - def finish(self) -> None: - self.file.close() + def acquire(self) -> int: + self.ref_count += 1 + return 
self.ref_count + + def release(self) -> int: + self.ref_count -= 1 + if self.ref_count <= 0: + self.file.close() + return self.ref_count diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 271c2931e2..6591ddde4a 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -19,15 +19,7 @@ class QueueReader(BufferReader): def __init__(self, storage_config: StorageConfig, config: BufferConfig): assert storage_config.storage_type == StorageType.QUEUE self.read_batch_size = config.read_batch_size - self.queue = ( - ray.remote(QueueActor) - .options( - name=f"queue-{storage_config.name}", - namespace=ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(storage_config, config) - ) + self.queue = QueueActor.get_actor(storage_config, config) def read( self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index 0fc4929ca5..16ec96d0a9 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -20,8 +20,15 @@ def write(self, data: List) -> None: else: self.writer.write(data) - def finish(self): + def acquire(self) -> int: if self.wrap_in_ray: - ray.get(self.writer.finish.remote()) + return ray.get(self.writer.acquire()) else: - self.writer.finish() + return 0 + + def release(self) -> int: + if self.wrap_in_ray: + return ray.get(self.writer.release.remote()) + else: + self.writer.release() + return 0 diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index ec2316a0ec..7b12fab4c1 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -18,18 +18,13 @@ class QueueWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig): assert meta.storage_type == StorageType.QUEUE self.config = config - self.queue = ( - ray.remote(QueueActor) - .options( - name=f"queue-{meta.name}", - namespace=ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(meta, config) - ) + self.queue = QueueActor.get_actor(meta, config) def write(self, data: List) -> None: ray.get(self.queue.put_batch.remote(data)) - def finish(self): - ray.get(self.queue.finish.remote()) + def acquire(self) -> int: + return ray.get(self.queue.acquire.remote()) + + def release(self) -> int: + return ray.get(self.queue.release.remote()) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 8864dc9b82..95344d4447 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -23,6 +23,15 @@ def write(self, data: list) -> None: else: self.db_wrapper.write(data) - def finish(self) -> None: - # TODO: implement this - pass + def acquire(self) -> int: + if self.wrap_in_ray: + return ray.get(self.db_wrapper.acquire.remote()) + else: + return 0 + + def release(self) -> int: + if self.wrap_in_ray: + return ray.get(self.db_wrapper.release.remote()) + else: + self.db_wrapper.release() + return 0 diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index e4123820de..15b669d61c 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -9,7 +9,6 @@ import ray from trinity.common.config import Config, DataPipelineConfig, load_config -from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log 
import get_logger @@ -23,7 +22,7 @@ def bench(config: Config) -> None: explorer = ( ray.remote(Explorer) .options( - name=EXPLORER_NAME, + name=config.explorer.name, namespace=ray.get_runtime_context().namespace, ) .remote(config) @@ -44,7 +43,7 @@ def explore(config: Config) -> None: explorer = ( ray.remote(Explorer) .options( - name=EXPLORER_NAME, + name=config.explorer.name, namespace=ray.get_runtime_context().namespace, ) .remote(config) @@ -64,7 +63,7 @@ def train(config: Config) -> None: trainer = ( ray.remote(Trainer) .options( - name=TRAINER_NAME, + name=config.trainer.name, namespace=ray.get_runtime_context().namespace, ) .remote(config) @@ -92,7 +91,7 @@ def both(config: Config) -> None: explorer = ( ray.remote(Explorer) .options( - name=EXPLORER_NAME, + name=config.explorer.name, namespace=namespace, ) .remote(config) @@ -100,7 +99,7 @@ def both(config: Config) -> None: trainer = ( ray.remote(Trainer) .options( - name=TRAINER_NAME, + name=config.trainer.name, namespace=namespace, ) .remote(config) @@ -127,7 +126,7 @@ def both(config: Config) -> None: ) ready = ray.get(ready_ref[0]) - if ready == TRAINER_NAME: + if ready == config.trainer.name: logger.info( "===========================================================\n" "> Launcher detected that the `Trainer` process has finished.\n" @@ -135,7 +134,7 @@ def both(config: Config) -> None: "===========================================================" ) ray.wait(wait_ref, timeout=5) - elif ready == EXPLORER_NAME: + elif ready == config.explorer.name: logger.info( "============================================================\n" "> Launcher detected that the `Explorer` process has finished.\n" diff --git a/trinity/common/config.py b/trinity/common/config.py index 5d60cf8c4c..e6130395b2 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -7,6 +7,8 @@ from omegaconf import OmegaConf from trinity.common.constants import ( + EXPLORER_NAME, + TRAINER_NAME, PromptType, ReadStrategy, StorageType, @@ -79,7 +81,7 @@ class StorageConfig: format: FormatConfig = field(default_factory=FormatConfig) index: int = 0 - # used for StorageType.SQL + # used for StorageType.SQL/FILE wrap_in_ray: bool = True # used for StorageType.QUEUE @@ -279,6 +281,7 @@ class BufferConfig: class ExplorerConfig: """Config for explorer.""" + name: str = EXPLORER_NAME # for workflow runner # number of workflow runners. # For sync engine (vllm), it should be equal to `engine_num`. 
@@ -300,6 +303,7 @@ class ExplorerConfig: @dataclass class TrainerConfig: + name: str = TRAINER_NAME trainer_type: str = "verl" save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb @@ -582,7 +586,7 @@ def check_and_update(self) -> None: # noqa: C901 # set namespace if self.ray_namespace is None or len(self.ray_namespace) == 0: - self.ray_namespace = f"{self.project}-{self.name}" + self.ray_namespace = f"{self.project}/{self.name}" # check algorithm self._check_algorithm() diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 79c0cfae01..0806bc9c7d 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -282,6 +282,7 @@ async def init_process_group( rank_offset: int, world_size: int, group_name: str, + explorer_name: str, backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, @@ -299,6 +300,7 @@ async def init_process_group( timeout, update_with_checkpoint, state_dict_meta, + explorer_name, ray.get_runtime_context().namespace, ), ) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 3efe88b000..59211f198a 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -96,6 +96,7 @@ def init_process_group( rank_offset: int, world_size: int, group_name: str, + explorer_name: str, backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, @@ -113,6 +114,7 @@ def init_process_group( timeout, update_with_checkpoint, state_dict_meta, + explorer_name, ray.get_runtime_context().namespace, ), ) diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 2a156b8a2a..235ac1b013 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -4,7 +4,6 @@ import torch import torch.distributed -from trinity.common.constants import EXPLORER_NAME from trinity.utils.distributed import init_process_group, is_ipv6_address from trinity.utils.log import get_logger @@ -23,6 +22,7 @@ def init_process_group( timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: list = None, + explorer_name: str = None, namespace: str = None, ): """Init torch process group for model weights update""" @@ -53,6 +53,7 @@ def init_process_group( group_name=group_name, ) logger.info("vLLM init_process_group finished.") + self._explorer_name = explorer_name self._namespace = namespace self._explorer_actor = None @@ -63,7 +64,9 @@ def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" assert self._state_dict_meta is not None if self._explorer_actor is None: - self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self._namespace) + self._explorer_actor = ray.get_actor( + name=self._explorer_name, namespace=self._namespace + ) for name, dtype_str, shape in self._state_dict_meta: if self._weight_update_rank == 0: weight = ray.get(self._explorer_actor.get_weight.remote(name)) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 1ec0653503..0d2d3bf8c1 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -5,6 +5,7 @@ from omegaconf import OmegaConf from trinity.common.config import BufferConfig, Config, SynchronizerConfig +from trinity.common.constants import EXPLORER_NAME from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -119,6 +120,7 @@ class ActorRolloutRef: ref: Ref = 
field(default_factory=Ref) rollout: Rollout = field(default_factory=Rollout) synchronizer: Optional[SynchronizerConfig] = None + explorer_name: str = EXPLORER_NAME @dataclass @@ -298,6 +300,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.synchronizer = config.synchronizer self.actor_rollout_ref.synchronizer = config.synchronizer + self.actor_rollout_ref.explorer_name = config.explorer.name # Actor / Critic config self.actor_rollout_ref.model.path = config.model.model_path diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py index 93be832cc7..6b6d126f9b 100644 --- a/trinity/data/core/dataset.py +++ b/trinity/data/core/dataset.py @@ -84,7 +84,7 @@ def write_to_buffer( buffer_config = self.buffer_config output_buffer = get_buffer_writer(output_storage_config, buffer_config) output_buffer.write(self.data.to_list()) - output_buffer.finish() + output_buffer.release() self.data = Dataset.from_list([]) def to_parquet(self, path: str): diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 31ade5f84b..a35afa4c87 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -15,7 +15,6 @@ from trinity.buffer.buffer import get_buffer_reader from trinity.common.config import Config from trinity.common.constants import ( - EXPLORER_NAME, ROLLOUT_WEIGHT_SYNC_GROUP_NAME, RunningStatus, SyncMethod, @@ -47,6 +46,7 @@ def __init__(self, config: Config): self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) + self.experience_buffer.acquire() self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0) self.taskset = get_buffer_reader( self.config.buffer.explorer_input.taskset, self.config.buffer @@ -55,7 +55,7 @@ def __init__(self, config: Config): self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, name=self.config.name, - role=EXPLORER_NAME, + role=self.config.explorer.name, config=config, ) self.batch_size = config.buffer.batch_size @@ -100,6 +100,7 @@ async def setup_weight_sync_group( + base_offset, world_size=world_size, group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + explorer_name=self.config.explorer.name, timeout=self.config.synchronizer.sync_timeout, update_with_checkpoint=self.use_checkpoint_weights_update, state_dict_meta=state_dict_meta, @@ -184,7 +185,7 @@ async def explore(self) -> str: self.logger.error(f"Error in Explorer: {e}") break self.logger.info("--------------------\n> Explorer finished.\n--------------------") - return EXPLORER_NAME + return self.config.explorer.name def explore_step(self) -> bool: algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1) @@ -202,7 +203,7 @@ def explore_step(self) -> bool: ) self.status = RunningStatus.STOPPED self.wait_for_workflow_done() - self.experience_buffer.finish() + self.experience_buffer.release() return False self.runner_pool.run_tasks(tasks) self.explore_step_num += 1 diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index baaf1242c3..4af6f28685 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -14,8 +14,8 @@ class CacheManager: def __init__(self, config: Config, check_config: bool = False): self.cache_dir = config.monitor.cache_dir # type: ignore - self.explorer_meta_path = os.path.join(self.cache_dir, "explorer_meta.json") # type: ignore - self.trainer_meta_path = os.path.join(self.cache_dir, "trainer_meta.json") # type: ignore + self.explorer_meta_path = os.path.join(self.cache_dir, 
f"{config.explorer.name}_meta.json") # type: ignore + self.trainer_meta_path = os.path.join(self.cache_dir, f"{config.trainer.name}_meta.json") # type: ignore if check_config: self._check_config_consistency(config) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 216c916c69..91f681e47f 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -10,12 +10,7 @@ import ray from trinity.common.config import Config -from trinity.common.constants import ( - EXPLORER_NAME, - TRAINER_NAME, - RunningStatus, - SyncMethod, -) +from trinity.common.constants import RunningStatus, SyncMethod from trinity.utils.log import get_logger @@ -45,7 +40,7 @@ def train(self) -> str: self.logger.error(f"Error in Trainer: {e}") break self.logger.info("--------------------\n> Trainer finished.\n--------------------") - return TRAINER_NAME + return self.config.trainer.name def train_step(self) -> bool: """Train one step. @@ -63,7 +58,7 @@ def sync_weight(self) -> None: """Sync the model weight.""" if self.config.synchronizer.sync_method == SyncMethod.NCCL: if self.explorer_ref is None: - self.explorer_ref = ray.get_actor(EXPLORER_NAME) + self.explorer_ref = ray.get_actor(self.config.explorer.name) explorer_status = ray.get(self.explorer_ref.running_status.remote()) if explorer_status == RunningStatus.STOPPED: self.logger.warning("Explorer has already stopped. Skipping sync weight.") diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index cbc88902a0..5e76375315 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -71,11 +71,7 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from trinity.common.config import AlgorithmConfig -from trinity.common.constants import ( - EXPLORER_NAME, - ROLLOUT_WEIGHT_SYNC_GROUP_NAME, - SyncMethod, -) +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) @@ -577,7 +573,7 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") - explorer = ray.get_actor(EXPLORER_NAME) + explorer = ray.get_actor(self.config.explorer_name) setup_ref = explorer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index d041bea128..7c789a98d2 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -36,7 +36,6 @@ from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config -from trinity.common.constants import TRAINER_NAME from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper from trinity.utils.monitor import MONITOR @@ -150,7 +149,7 @@ def __init__( self.logger = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, name=config.trainer.experiment_name, - role=TRAINER_NAME, + role=global_config.trainer.name, config=global_config, ) self.reset_experiences_example_table() diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index f12a854335..965fb7e4df 100644 --- a/trinity/utils/monitor.py +++ 
b/trinity/utils/monitor.py @@ -69,7 +69,7 @@ def calculate_metrics( @MONITOR.register_module("tensorboard") class TensorboardMonitor(Monitor): def __init__(self, project: str, name: str, role: str, config: Config = None) -> None: - self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard") + self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role) os.makedirs(self.tensorboard_dir, exist_ok=True) self.logger = SummaryWriter(self.tensorboard_dir) self.console_logger = get_logger(__name__)
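
A minimal sketch of the reference-counted writer lifecycle this change set introduces (`acquire`/`release` replacing the old `finish`), modeled on `tests/buffer/queue_test.py` above. The buffer name and the bare `BufferConfig()` defaults are illustrative assumptions, not values required by the library.

```python
# Sketch only: mirrors tests/buffer/queue_test.py; names/values marked below are assumptions.
import ray

from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType

ray.init(ignore_reinit_error=True)

# Same construction as in TestFullyAsyncMode; "exp_buffer" is an illustrative name.
meta = StorageConfig(name="exp_buffer", storage_type=StorageType.QUEUE, wrap_in_ray=True)
config = BufferConfig()  # assumed defaults; the unit tests set explicit batch sizes

writer = QueueWriter(meta, config)  # creates or attaches to the "queue-exp_buffer" Ray actor
reader = QueueReader(meta, config)  # attaches to the same actor by name

assert writer.acquire() == 1   # register this producer; returns the new reference count
# writer.write(experiences)    # `experiences` would be a list of Experience objects,
#                              # constructed as in the buffer tests above
assert writer.release() == 0   # count drops to 0: the queue is closed

try:
    reader.read()              # a closed queue signals consumers to stop
except StopIteration:
    print("experience buffer closed")
```

Reference counting matters because, in fully asynchronous mode, several explorers can attach to the same experience buffer (as in `TestFullyAsyncMode`): the queue is only closed, and the trainer's reader only sees `StopIteration`, after the last producer releases it.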