torch hub load
时间: 2025-07-02 19:08:14 浏览: 10
### 如何使用 `torch.hub.load` 加载预训练模型
通过 PyTorch 提供的 `torch.hub.load()` 方法可以方便地加载来自 GitHub 或本地目录中的预训练模型。以下是具体方法及其细节:
#### 基本语法
函数签名如下:
```python
torch.hub.load(repo_or_dir, model, *args, **kwargs)
```
- `repo_or_dir`: 表示存储库的位置,可以是一个 GitHub 存储库名称(如 `'pytorch/vision'`),也可以是本地文件夹路径。
- `model`: 要加载的具体模型名称,通常由目标仓库提供支持。
- `*args`, `**kwargs`: 传递给模型初始化函数的额外参数。
此函数最终会返回一个已实例化的模块对象[^4]。
#### 实际案例演示
下面展示如何利用该功能下载并运行 ResNet50 预训练版本的例子:
```python
import torch
# 设置验证标志位为真以绕过分支检测错误
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
# 使用 PyTorch 官方视觉库加载 resnet50 模型,并附带预训练权重
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# 将模型切换至评估模式以及 GPU 设备上执行推理操作
model.eval()
model.to("cuda")
```
在此基础上还可以进一步借助其他工具比如 Torch-TensorRT 对其性能加以优化[^2]。
#### 关于模型保存与恢复
当完成上述过程之后如果打算长期保留所获取到的结果,则需考虑妥善处理相关数据结构。一般情况下我们仅储存必要的状态字典部分而非整个类定义本身;因此,在另一次启动程序时重新构建相同的架构变得至关重要[^3]。
例如先创建对应的网络实体后再导入先前记录下的数值即可实现无缝衔接:
```python
from my_custom_networks import CustomResNet50
loaded_model = CustomResNet50()
checkpoint = torch.load('path_to_checkpoint_file')
loaded_model.load_state_dict(checkpoint['model_state'])
```
以上即完成了基于 PyTorch HUB 的标准流程介绍。
阅读全文
相关推荐


















