如何使用 Pytorch Lightning 启用早停机制

【PL 基础】如何启用早停机制

摘要

  本文介绍了两种在 PyTorch Lightning 中实现早停机制的方法。第一种是通过重写on_train_batch_start()方法手动控制训练流程;第二种是使用内置的EarlyStopping回调,可以监控验证指标并在指标停止改善时自动停止训练。文章详细说明了EarlyStopping的参数设置,包括监控指标、模式选择、耐心值等核心参数,以及停止阈值、发散阈值等进阶参数。同时介绍了如何通过子类化修改早停触发时机,并提醒注意验证频率与耐心值的配合使用。文末提供了完整的代码示例,展示了如何在实际训练中配置和使用早停机制。

1. on_train_batch_start()

  通过重写 on_train_batch_start() 方法,在满足特定条件时提前返回,从而停止并跳过当前epoch的剩余训练批次。

  如果对于最初要求的每个epoch重复这样做,将停止整个训练。

2. EarlyStopping Callback

  EarlyStopping 回调可用于监控指标,并在没有观察到改善时停止训练。

要启用此功能,请执行以下操作:

  • 导入 EarlyStopping 回调模块;

  • 使用 log() 方法记录需要监控的指标;

  • 初始化回调并设置要监控的指标名称(monitor 参数);

  • 根据指标特性设置监控模式(mode 参数);

  • EarlyStopping 回调传递给 Trainercallbacks 参数。

from lightning.pytorch.callbacks.early_stopping import EarlyStopping


class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)


model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

可以通过更改其参数来自定义回调行为。

early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, 
                                    patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])

用于在极值点停止训练的附加参数:

  • stopping_threshold(停止阈值):当监控指标达到该阈值时立即终止训练。适用于已知超过特定最优值后模型不再提升的场景。

  • divergence_threshold(发散阈值):当监控指标劣化至该阈值时即刻停止训练。当指标恶化至此程度时,我们认为模型已无法恢复,此时应提前终止并尝试不同初始条件。

  • check_finite(有限值检测):启用后,若监控指标出现NaN(非数值)或无穷大时终止训练。

  • check_on_train_epoch_end(训练周期结束检测):启用后,在训练周期结束时检查指标。仅当监控指标通过周期级训练钩子记录时才需启用此功能。

若需在训练过程的其他阶段启用早停机制,请通过创建子类继承 EarlyStopping 类并修改其调用位置:

class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # override this to disable early stopping at the end of val loop
        pass

    def on_train_end(self, trainer, pl_module):
        # instead, do it at the end of training loop
        self._run_early_stopping_check(trainer)

默认情况下,EarlyStopping 回调会在每个验证周期结束时触发。但验证频率可通过 Trainer 中的参数调节,例如通过设置 check_val_every_n_epoch(每N个训练周期验证一次)和 val_check_interval(验证间隔)。需特别注意:patience(耐心值)统计的是验证结果未提升的次数,而非训练周期数。因此当设置 check_val_every_n_epoch=10patience=3 时,训练器需经历至少 40个训练周期才会停止。

PyTorch Lightning支持多机多卡的模型训练,可以使用`DDP`(分布式数据并行)模块来实现。 以下是一个简单的多机多卡训练的例子: ```python import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin class MyLightningModule(pl.LightningModule): def __init__(self): super().__init__() # 定义模型 def forward(self, x): # 前向传播 def training_step(self, batch, batch_idx): # 定义训练步骤 def configure_optimizers(self): # 定义优化器 # 实例化模型 model = MyLightningModule() # 实例化Trainer trainer = Trainer( gpus=2, # 每台机器使用2个GPU num_nodes=2, # 使用2台机器 accelerator='ddp', # 使用DDP plugins=DDPPlugin(find_unused_parameters=False) # 使用DDP插件并禁用未使用参数的检测 ) # 开始训练 trainer.fit(model) ``` 在这个例子中,我们使用了`Trainer`类来进行模型训练。`gpus`参数指定每台机器使用的GPU数量,`num_nodes`参数指定使用的机器数量,`accelerator`参数指定使用的加速器类型为`ddp`,即使用DDP模式进行分布式训练。 同时,我们使用了`DDPPlugin`插件来启用DDP模式的训练,并且禁用了未使用参数的检测,以避免出现不必要的警告信息。 在实际的多机多卡训练中,需要注意的是,不同机器之间需要进行网络连接,因此需要在运行训练之前进行一些配置工作,以确保不同机器之间的通信正常。同时,还需要注意在训练过程中可能出现的一些问题,如通信延迟、负载均衡等,需要进行适当的调优。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值