TensorFlow 中文手册

19.1 模型保存与加载的三种方式

TensorFlow 模型保存与加载的三种方式详解 | TensorFlow 中文学习手册

TensorFlow 中文手册

本章节详细讲解 TensorFlow 中模型保存与加载的三种主要方式:SavedModel 格式保存整个模型、HDF5 格式保存权重、以及 JSON/YAML 格式保存模型结构。通过对比适用场景和优缺点,帮助初学者快速掌握如何根据需求选择最佳保存方式。

推荐工具
PyCharm专业版开发必备

功能强大的Python IDE,提供智能代码补全、代码分析、调试和测试工具,提高Python开发效率。特别适合处理列表等数据结构的开发工作。

了解更多

TensorFlow 模型保存与加载的三种方式

引言

在 TensorFlow 中,训练完成的模型通常需要保存下来,以便后续使用,例如部署到生产环境、共享给他人或恢复训练进度。保存模型是深度学习工作流程中的关键步骤。TensorFlow 提供了多种保存方式,每种方式都有其独特的优势和适用场景。本章将详细介绍三种常用的保存方式,并对它们进行对比,帮助你轻松入门。

1. 保存整个模型(SavedModel 格式,推荐)

SavedModel 是 TensorFlow 官方推荐的格式,因为它能够保存模型的完整信息,包括结构、权重和优化器状态。这就像一个“快照”,可以完整地恢复模型。

如何保存

使用 model.save() 方法即可保存为 SavedModel 格式。默认情况下,它会保存整个模型。

import tensorflow as tf

# 假设已经构建并训练了一个模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, input_shape=(5,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, epochs=5)  # 假设有训练数据

# 保存整个模型
model.save('my_model')  # 这会创建一个名为 'my_model' 的文件夹

如何加载

加载时,使用 tf.keras.models.load_model() 方法,可以恢复所有信息。

# 加载整个模型
loaded_model = tf.keras.models.load_model('my_model')

# 可以直接使用,包括继续训练或预测
loaded_model.fit(x_new_data, y_new_data, epochs=5)  # 如果需要继续训练
predictions = loaded_model.predict(x_test)  # 进行预测

优点

  • 完整性:保存结构、权重和优化器状态,可以无缝恢复训练或部署。
  • 方便:加载后无需额外步骤,直接可用。
  • 标准格式:是 TensorFlow 的通用格式,适用于跨平台部署。

缺点

  • 文件较大:由于包含所有数据,SavedModel 文件夹可能占用较多存储空间。

适用场景

  • 生产环境部署,例如使用 TensorFlow Serving。
  • 需要完整恢复训练状态的情况,如中断后继续训练。
  • 共享模型给他人,确保对方能直接使用。

2. 保存权重(HDF5 格式)

如果你只关心模型的权重(即参数),可以使用 HDF5 格式保存权重,而结构则单独处理。这可以节省存储空间,但需要重建模型结构。

如何保存

使用 model.save_weights() 方法保存权重为 HDF5 文件。

# 保存权重
model.save_weights('model_weights.h5')

如何加载

加载权重前,需要重建相同的模型结构。

# 首先重建模型结构(假设和保存时相同)
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, input_shape=(5,)),
    tf.keras.layers.Dense(1)
])
new_model.compile(optimizer='adam', loss='mse')

# 加载权重
new_model.load_weights('model_weights.h5')

# 现在模型可用于预测或进一步训练
predictions = new_model.predict(x_test)

优点

  • 文件较小:只保存权重,减少存储开销。
  • 灵活性:可以独立于结构更新或共享权重。

缺点

  • 依赖结构:加载时必须使用完全相同的模型结构,否则可能导致错误。
  • 不保存优化器状态:无法直接恢复训练进度。

适用场景

  • 预训练模型的权重分享,例如迁移学习中只使用权重。
  • 当你只想保存和加载参数,而不需要完整模型时。
  • 为了版本控制或比较不同训练阶段的权重。

3. 保存模型结构(JSON/YAML 格式)

有时,你可能只想保存模型的结构定义,而不包含权重。这在快速原型设计或文档化时很有用。TensorFlow 支持将结构保存为 JSON 或 YAML 格式。

如何保存

使用 model.to_json()model.to_yaml() 方法获取结构字符串,然后写入文件。

# 保存为 JSON
model_json = model.to_json()
with open('model_structure.json', 'w') as json_file:
    json_file.write(model_json)

# 保存为 YAML(类似方式)
model_yaml = model.to_yaml()
with open('model_structure.yaml', 'w') as yaml_file:
    yaml_file.write(model_yaml)

如何加载

加载结构后,需要结合权重文件来重建模型。

# 从 JSON 加载结构
from tensorflow.keras.models import model_from_json

with open('model_structure.json', 'r') as json_file:
    loaded_model = model_from_json(json_file.read())

# 然后加载权重(假设有对应的权重文件)
loaded_model.load_weights('model_weights.h5')

# 编译模型(如果需要训练)
loaded_model.compile(optimizer='adam', loss='mse')

优点

  • 轻量:文件非常小,便于共享或版本控制。
  • 易于理解:JSON 或 YAML 格式是文本文件,便于人类阅读和编辑。

缺点

  • 不完整:只包含结构,没有权重或优化器状态。
  • 需要额外步骤:必须单独保存和加载权重才能使用模型。

适用场景

  • 模型设计的文档化,便于团队协作。
  • 快速重建或修改模型结构,而不重复定义代码。
  • 用于生成模型图或可视化结构。

适用场景对比

为了更直观地理解三种方式的区别,以下是一个简单对比表:

保存方式 保存内容 文件大小 加载难度 适用场景
SavedModel 格式 结构 + 权重 + 优化器状态 简单(直接加载) 生产部署、完整恢复训练
HDF5 权重 仅权重 中等(需重建结构) 预训练模型分享、权重更新
JSON/YAML 结构 仅结构 很小 高(需结合权重) 结构设计、文档化、快速原型

关键点总结

  • SavedModel:最完整,适合大多数生产和使用场景。
  • HDF5 权重:如果你只关心参数,这是节省空间的选择。
  • JSON/YAML 结构:纯结构保存,便于灵活修改和分享设计。

总结

选择哪种保存方式取决于你的具体需求:

  • 如果你的目标是快速部署或无缝恢复训练,SavedModel 格式是最佳选择。
  • 如果只存储权重用于迁移学习或节省空间,使用 HDF5 格式
  • 如果需要共享或备份模型结构定义,JSON 或 YAML 格式更合适。

在实际项目中,通常会结合使用这些方式:例如,保存完整模型用于部署,同时保存权重和结构用于备份或协作。通过本章学习,你应该能够根据场景灵活选择保存方式,有效管理 TensorFlow 模型。

下一章,我们将探讨如何优化和评估训练好的模型,以进一步提升性能。祝你学习愉快!

开发工具推荐
Python开发者工具包

包含虚拟环境管理、代码格式化、依赖管理、测试框架等Python开发全流程工具,提高开发效率。特别适合处理复杂数据结构和算法。

获取工具包