在加载预训练权重的时候,有部分代码需要进行进一步的更改,可以采取以下的方式:
# 创建模型
model = create_model(num_classes=args.num_classes)
weights_dict = torch.load(args.weights, map_location=device)
# 删除最后一层的分类层权重(classifier层)
del weights_dict['classifier.1.weight']
del weights_dict['classifier.1.bias']
model.load_state_dict(weights_dict, strict=False)
model.to(device)
# 加载预训练权重
if args.weights:
print(f"Loading weights from {args.weights}")
weights_dict = torch.load(args.weights, map_location=device)
# # 删除与当前模型不匹配的权重(例如分类头)
# # del_keys = ['classifier.3.weight', 'classifier.3.bias']#MobileNetv3
del_keys = [' classifier.1.weight', 'classifier.1.bias']
for key in del_keys:
if key in weights_dict:
del weights_dict[key]
# # 加载剩余的权重
missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")
# # 将模型移动到设备
model.to(device)