使用FARM框架实现文档分类的保留集验证方法

使用FARM框架实现文档分类的保留集验证方法

概述

本文将详细介绍如何使用FARM框架中的保留集(holdout)验证方法进行文档分类任务。保留集验证是交叉验证的一种形式,特别适用于数据量有限的情况,能够更准确地评估模型性能。

环境准备

首先需要设置日志记录和实验跟踪系统。代码中使用MLFlow来记录实验过程,方便后续分析和比较不同实验的结果。

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO)

mlflogger = MLFlowLogger(tracking_uri="logs")
mlflogger.init_experiment(experiment_name="Example-docclass-xval", run_name="testrun1")

核心参数配置

保留集验证的关键参数包括:

  • holdout_splits: 设置保留集的分割次数(默认为5)
  • holdout_train_split: 训练集比例(默认为0.8)
  • holdout_stratification: 是否进行分层抽样(保持类别分布)
holdout_splits = 5
holdout_train_split = 0.8
holdout_stratification = True

数据处理流程

1. 数据处理器(Processor)创建

使用TextClassificationProcessor处理原始文本数据,主要功能包括:

  • 文本分词
  • 序列长度截断
  • 标签转换
  • 数据集分割
processor = TextClassificationProcessor(
    tokenizer=tokenizer,
    max_seq_len=64,
    data_dir=Path("../data/germeval18"),
    label_list=label_list,
    metric=metric,
    dev_split=dev_split,
    dev_stratification=dev_stratification,
    label_column_name="coarse_label"
)

2. 数据仓库(DataSilo)构建

DataSilo负责:

  • 加载处理后的数据集
  • 提供数据加载器(DataLoader)
  • 计算数据集统计信息
data_silo = DataSilo(
    processor=processor,
    batch_size=batch_size)

3. 保留集分割

使用DataSiloForHoldout.make方法创建保留集分割:

silos = DataSiloForHoldout.make(
    data_silo,
    sets=["train", "dev"],
    n_splits=holdout_splits,
    train_split=holdout_train_split,
    stratification=holdout_stratification)

模型训练与评估

1. 模型构建

构建自适应模型(AdaptiveModel),包含:

  • 预训练语言模型(如BERT)
  • 文本分类预测头
language_model = LanguageModel.load(lang_model)
prediction_head = TextClassificationHead(
    class_weights=data_silo.calculate_class_weights(task_name="text_classification"),
    num_labels=len(label_list))

model = AdaptiveModel(
    language_model=language_model,
    prediction_heads=[prediction_head],
    embeds_dropout_prob=0.2,
    lm_output_types=["per_sequence"],
    device=device)

2. 早停机制(EarlyStopping)

设置早停机制防止过拟合:

earlystopping = EarlyStopping(
    metric="f1_offense", 
    mode="max",
    save_dir=save_dir,
    patience=5)

3. 训练过程

使用Trainer进行模型训练:

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    data_silo=silo_to_use,
    epochs=n_epochs,
    n_gpu=n_gpu,
    lr_schedule=lr_schedule,
    evaluate_every=evaluate_every,
    device=device,
    early_stopping=earlystopping,
    evaluator_test=False)

评估指标

自定义评估指标函数mymetrics,包含:

  • 准确率(Accuracy)
  • 各类别的F1分数
  • 宏观/微观平均F1
  • Matthews相关系数(MCC)
def mymetrics(preds, labels):
    acc = simple_accuracy(preds, labels).get("acc")
    f1other = f1_score(y_true=labels, y_pred=preds, pos_label="OTHER")
    f1offense = f1_score(y_true=labels, y_pred=preds, pos_label="OFFENSE")
    f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
    f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")
    mcc = matthews_corrcoef(labels, preds)
    return {
        "acc": acc,
        "f1_other": f1other,
        "f1_offense": f1offense,
        "f1_macro": f1macro,
        "f1_micro": f1micro,
        "mcc": mcc
    }

结果分析与汇总

在所有保留集分割完成后,计算各指标的均值和标准差:

eval_metric = {}
eval_metric["task_name"] = allresults[0][0].get("task_name", "UNKNOWN TASKNAME")
for name in eval_metric_lists_head0.keys():
    values = eval_metric_lists_head0[name]
    vmean = statistics.mean(values)
    vstdev = statistics.stdev(values)
    eval_metric[name+"_mean"] = vmean
    eval_metric[name+"_stdev"] = vstdev

最佳模型测试

使用性能最好的模型在原始测试集上进行最终评估:

evaluator_origtest = Evaluator(
    data_loader=data_silo.get_data_loader("test"),
    tasks=data_silo.processor.tasks,
    device=device
)

model = AdaptiveModel.load(save_dir, device, lm_name=lm_name)
model.connect_heads_with_processor(data_silo.processor.tasks, require_labels=True)
result = evaluator_origtest.eval(model)

总结

本文详细介绍了使用FARM框架进行文档分类任务的保留集验证方法。保留集验证能够更准确地评估模型性能,特别适用于数据量有限的情况。通过自定义评估指标、早停机制和多轮验证,可以构建出性能稳定、泛化能力强的文本分类模型。

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

石淞畅Oprah

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值