使用FARM框架实现文档分类的保留集验证方法
概述
本文将详细介绍如何使用FARM框架中的保留集(holdout)验证方法进行文档分类任务。保留集验证是交叉验证的一种形式,特别适用于数据量有限的情况,能够更准确地评估模型性能。
环境准备
首先需要设置日志记录和实验跟踪系统。代码中使用MLFlow来记录实验过程,方便后续分析和比较不同实验的结果。
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO)
mlflogger = MLFlowLogger(tracking_uri="logs")
mlflogger.init_experiment(experiment_name="Example-docclass-xval", run_name="testrun1")
核心参数配置
保留集验证的关键参数包括:
holdout_splits
: 设置保留集的分割次数(默认为5)holdout_train_split
: 训练集比例(默认为0.8)holdout_stratification
: 是否进行分层抽样(保持类别分布)
holdout_splits = 5
holdout_train_split = 0.8
holdout_stratification = True
数据处理流程
1. 数据处理器(Processor)创建
使用TextClassificationProcessor
处理原始文本数据,主要功能包括:
- 文本分词
- 序列长度截断
- 标签转换
- 数据集分割
processor = TextClassificationProcessor(
tokenizer=tokenizer,
max_seq_len=64,
data_dir=Path("../data/germeval18"),
label_list=label_list,
metric=metric,
dev_split=dev_split,
dev_stratification=dev_stratification,
label_column_name="coarse_label"
)
2. 数据仓库(DataSilo)构建
DataSilo
负责:
- 加载处理后的数据集
- 提供数据加载器(DataLoader)
- 计算数据集统计信息
data_silo = DataSilo(
processor=processor,
batch_size=batch_size)
3. 保留集分割
使用DataSiloForHoldout.make
方法创建保留集分割:
silos = DataSiloForHoldout.make(
data_silo,
sets=["train", "dev"],
n_splits=holdout_splits,
train_split=holdout_train_split,
stratification=holdout_stratification)
模型训练与评估
1. 模型构建
构建自适应模型(AdaptiveModel),包含:
- 预训练语言模型(如BERT)
- 文本分类预测头
language_model = LanguageModel.load(lang_model)
prediction_head = TextClassificationHead(
class_weights=data_silo.calculate_class_weights(task_name="text_classification"),
num_labels=len(label_list))
model = AdaptiveModel(
language_model=language_model,
prediction_heads=[prediction_head],
embeds_dropout_prob=0.2,
lm_output_types=["per_sequence"],
device=device)
2. 早停机制(EarlyStopping)
设置早停机制防止过拟合:
earlystopping = EarlyStopping(
metric="f1_offense",
mode="max",
save_dir=save_dir,
patience=5)
3. 训练过程
使用Trainer
进行模型训练:
trainer = Trainer(
model=model,
optimizer=optimizer,
data_silo=silo_to_use,
epochs=n_epochs,
n_gpu=n_gpu,
lr_schedule=lr_schedule,
evaluate_every=evaluate_every,
device=device,
early_stopping=earlystopping,
evaluator_test=False)
评估指标
自定义评估指标函数mymetrics
,包含:
- 准确率(Accuracy)
- 各类别的F1分数
- 宏观/微观平均F1
- Matthews相关系数(MCC)
def mymetrics(preds, labels):
acc = simple_accuracy(preds, labels).get("acc")
f1other = f1_score(y_true=labels, y_pred=preds, pos_label="OTHER")
f1offense = f1_score(y_true=labels, y_pred=preds, pos_label="OFFENSE")
f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")
mcc = matthews_corrcoef(labels, preds)
return {
"acc": acc,
"f1_other": f1other,
"f1_offense": f1offense,
"f1_macro": f1macro,
"f1_micro": f1micro,
"mcc": mcc
}
结果分析与汇总
在所有保留集分割完成后,计算各指标的均值和标准差:
eval_metric = {}
eval_metric["task_name"] = allresults[0][0].get("task_name", "UNKNOWN TASKNAME")
for name in eval_metric_lists_head0.keys():
values = eval_metric_lists_head0[name]
vmean = statistics.mean(values)
vstdev = statistics.stdev(values)
eval_metric[name+"_mean"] = vmean
eval_metric[name+"_stdev"] = vstdev
最佳模型测试
使用性能最好的模型在原始测试集上进行最终评估:
evaluator_origtest = Evaluator(
data_loader=data_silo.get_data_loader("test"),
tasks=data_silo.processor.tasks,
device=device
)
model = AdaptiveModel.load(save_dir, device, lm_name=lm_name)
model.connect_heads_with_processor(data_silo.processor.tasks, require_labels=True)
result = evaluator_origtest.eval(model)
总结
本文详细介绍了使用FARM框架进行文档分类任务的保留集验证方法。保留集验证能够更准确地评估模型性能,特别适用于数据量有限的情况。通过自定义评估指标、早停机制和多轮验证,可以构建出性能稳定、泛化能力强的文本分类模型。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考