24.1 计算机视觉:自定义图像分类系统(迁移学习 + Flask)
TensorFlow 中文学习手册:图像分类系统构建教程(迁移学习 + Flask)
本章节详细讲解如何使用 TensorFlow 构建自定义图像分类系统,涵盖业务需求、数据集预处理、迁移学习模型构建、训练调优和 Flask 本地 API 部署。适合深度学习新手快速上手,并提供完整代码示例。
计算机视觉:构建自定义图像分类系统(迁移学习 + Flask)
引言
欢迎来到 TensorFlow 学习手册的计算机视觉章节!本章将带您从头构建一个自定义的图像分类系统。我们将利用迁移学习技术,结合 MobileNetV2 预训练模型,并通过 Flask 框架将模型封装为本地 API。无论您是新手还是想巩固知识,这个教程都会逐步引导您完成猫狗分类、花卉分类或商品分类等实际项目。
1. 业务需求
图像分类在现实生活中有广泛的应用场景,例如:
- 猫狗分类:自动识别用户上传的宠物照片,区分猫和狗。
- 花卉分类:开发植物识别应用,帮助用户识别不同种类的花卉。
- 商品分类:在电商平台中,自动将商品图片分类到相应类别,提高管理效率。
这些需求都可以通过构建一个自定义图像分类系统来实现。迁移学习让我们能快速高效地训练模型,即使数据集较小。
2. 数据集准备与预处理
图像加载
在 TensorFlow 中,我们通常使用 tf.keras.preprocessing.image 模块加载图像。例如,假设您有一个包含猫狗图片的文件夹,可以使用以下方式加载:
from tensorflow.keras.preprocessing import image
import numpy as np
img_path = 'path/to/your/image.jpg'
img = image.load_img(img_path, target_size=(224, 224)) # 调整大小为 MobileNetV2 要求的 224x224
img_array = image.img_to_array(img) # 转换为数组
img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
对于大型数据集,推荐使用 tf.data API 进行高效加载和预处理。
数据增强
数据增强是防止过拟合的重要手段。TensorFlow 提供了 ImageDataGenerator 来轻松实现。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
# 使用生成器加载数据集
train_generator = datagen.flow_from_directory(
'train_dir', # 训练图像目录
target_size=(224, 224),
batch_size=32,
class_mode='categorical'
)
TFRecord 转换(可选)
对于超大数据集,使用 TFRecord 格式可以提升训练性能。TFRecord 是一种高效的二进制文件格式。您可以先将图像和标签转换为 TFRecord 文件,然后用 tf.data.TFRecordDataset 加载。
import tensorflow as tf
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 示例:将图像写入 TFRecord
with tf.io.TFRecordWriter('images.tfrecord') as writer:
for img, label in dataset:
feature = {
'image': _bytes_feature(tf.io.encode_jpeg(img).numpy()),
'label': _int64_feature(label)
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
3. 迁移学习模型构建
迁移学习让我们可以利用在大型数据集(如 ImageNet)上预训练的模型,快速适应新任务。这里我们使用 MobileNetV2,因为它轻量且适合移动部署。
加载 MobileNetV2 预训练模型
首先,加载 MobileNetV2 模型,并冻结其权重以防止在初始训练中更新。
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
# 加载预训练模型,不包括顶层(分类层)
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False # 冻结基础模型,只训练添加的新层
添加自定义分类头
根据您的分类任务(如猫狗分类有 2 个类别),添加新的全连接层。
# 获取基础模型的输出
x = base_model.output
x = GlobalAveragePooling2D()(x) # 全局平均池化,减少参数量
x = Dense(128, activation='relu')(x) # 添加一个全连接层
predictions = Dense(num_classes, activation='softmax')(x) # num_classes 是类别数,如 2 表示猫狗分类
# 构建完整模型
model = Model(inputs=base_model.input, outputs=predictions)
4. 模型训练与调优
编译模型
编译模型时,选择合适的优化器和损失函数。
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
训练模型
使用数据生成器进行训练。
history = model.fit(
train_generator,
epochs=10, # 初始训练轮次
validation_data=val_generator # 验证集
)
微调
训练一些轮次后,可以解冻部分基础模型层进行微调,以提高性能。
base_model.trainable = True # 解冻所有层
# 或者只解冻最后几层
for layer in base_model.layers[:100]:
layer.trainable = False
# 重新编译模型,使用较低的学习率
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_generator, epochs=5, validation_data=val_generator) # 继续训练
早停
使用 EarlyStopping 回调防止过拟合。
from tensorflow.keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(monitor='val_loss', patience=3) # 如果验证损失连续 3 轮不改善,则停止训练
model.fit(train_generator, epochs=20, validation_data=val_generator, callbacks=[early_stopping])
模型保存
训练完成后,保存模型以便后续使用。
model.save('my_image_classifier.h5') # 保存为 HDF5 格式
5. 模型轻量化与 Flask 封装
模型轻量化(可选)
如果您计划在移动设备上部署,可以使用 TensorFlow Lite 将模型转换为轻量格式。
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
Flask 封装
创建一个简单的 Flask 应用,将模型部署为本地 API。
from flask import Flask, request, jsonify
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
app = Flask(__name__)
model = load_model('my_image_classifier.h5') # 加载训练好的模型
def preprocess_image(img_file):
"""预处理上传的图像"""
img = image.load_img(img_file, target_size=(224, 224))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array /= 255.0 # 归一化(如果模型训练时使用了)
return img_array
@app.route('/predict', methods=['POST'])
def predict():
if 'image' not in request.files:
return jsonify({'error': 'No image uploaded'}), 400
file = request.files['image']
img_array = preprocess_image(file)
prediction = model.predict(img_array)
class_idx = np.argmax(prediction[0])
confidence = float(prediction[0][class_idx])
# 假设类别标签为 ['cat', 'dog']
labels = ['cat', 'dog']
result = {
'predicted_class': labels[class_idx],
'confidence': confidence
}
return jsonify(result)
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=5000)
运行 Flask 应用后,您可以通过发送 POST 请求到 http://localhost:5000/predict 并附带图像文件来测试分类。
总结
本章详细介绍了使用 TensorFlow 构建自定义图像分类系统的完整流程:从理解业务需求、准备数据集,到构建迁移学习模型、训练调优,最后轻量化并封装为 Flask API。迁移学习让新手也能快速上手,Flask 则提供了便捷的本地部署方式。建议您动手尝试,修改代码以适应自己的分类任务,如花卉或商品分类。祝您学习愉快,早日成为 TensorFlow 高手!