Python 机器学习中,模型保存和加载是两个非常重要的操作。模型保存可以将训练好的模型保存到文件,以便以后使用。模型加载可以将保存的文件加载到内存,以便进行预测或评估。最常用保存和加模型的库包括pickle和joblib,另外在使用特定的机器学习库,如scikit-learn、TensorFlow或PyTorch时,它们也提供了自己的保存和加载机制。
参考文档:Python 机器学习 模型保存和加载-CJavaPy
1、pickle
pickle
模块是Python的一部分,提供了一个简单的方式来序列化和反序列化一个Python对象结构。训练好的模型通常需要被保存,以便于未来进行预测时能够直接加载使用,而不需要重新训练。pickle模块是Python中一个常用的进行对象序列化和反序列化的模块,它可以将Python对象转换为字节流,从而能够将对象保存到文件中,或者从文件中恢复对象。
1)模型的保存
pickle.dump()
方法用于将Python对象序列化并保存到文件中。常用参数如下,
参数 |
类型 |
描述 |
obj |
对象 |
要被序列化的Python对象。 |
file |
文件对象 |
一个打开的文件对象, 必须以二进制写模式打开('wb')。 |
protocol |
整数/None |
指定pickle数据格式的版本号。 如果省略,则使用默认的协议。可选的协议版本号从0到5, 其中更高的版本 提供了更高的效率和新的功能。 |
fix_imports |
布尔值 |
仅在Python 2和Python 3之间的互操作性中使用。 默认为True, 为了使pickle文件在不同的Python版本间 能够互相兼容。 |
buffer_callback |
回调函数/None |
一个可选的回调函数, 用于pickle协议版本5中, 为了提供对大型数据的优化处理机制。 仅在Python 3.8及以上版本中可用。 |
使用代码:
import pickle
# 创建一个复杂的数据结构
my_data = {
'name': 'Python',
'version': 3.8,