10.2 数据集的创建
TensorFlow数据集创建完整教程 - 从张量到文件与生成器
本文作为TensorFlow中文学习手册的一部分,详细介绍如何创建数据集,包括从张量或数组使用from_tensor_slices方法、从CSV/TFRecord/图像文件(后续详解)、使用from_generator适配自定义数据,以及数据集的基本结构如特征+标签和多输入/多输出数据集,适合新手快速上手。
TensorFlow数据集创建
引言
在TensorFlow中,数据集(Dataset)是处理大量数据的核心工具,它提供了高效的数据加载、预处理和流水线操作,能够帮助模型训练更快速和可扩展。本章节将介绍几种常见的数据集创建方法,以及数据集的基本结构,让你轻松入门。
1. 从张量或数组创建数据集
TensorFlow提供了tf.data.Dataset.from_tensor_slices方法来直接从Python张量或NumPy数组创建数据集。这对于小型或内存中已有的数据非常方便。
示例代码
import tensorflow as tf
import numpy as np
# 创建样本数据
features = np.array([[1, 2], [3, 4], [5, 6]]) # 特征数据,形状(3, 2)
labels = np.array([0, 1, 0]) # 标签数据,形状(3,)
# 使用from_tensor_slices创建数据集
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# 查看数据集元素
for feature, label in dataset.take(3):
print(f"特征: {feature.numpy()}, 标签: {label.numpy()}")
说明:from_tensor_slices将数据切片成多个元素,每个元素对应一个样本。在示例中,数据集的每个元素包含一个特征和对应的标签。
2. 从文件创建数据集
对于大规模数据,通常从文件(如CSV、TFRecord或图像文件)中读取,以减少内存占用。TensorFlow支持多种文件格式,这里简要介绍,后续章节会详细展开。
- CSV文件:使用
tf.data.experimental.CsvDataset或相关API,适合表格数据。 - TFRecord文件:TensorFlow的高效二进制格式,适用于大数据集,提供
tf.data.TFRecordDataset。 - 图像文件:通过
tf.data.Dataset.list_files和tf.io.decode_image等处理图像数据。
提示
我们将在后续章节深入探讨如何从文件创建数据集,包括具体代码示例和最佳实践。
3. 生成器创建数据集
当数据生成逻辑复杂或数据源不支持直接加载时,可以使用tf.data.Dataset.from_generator。这允许你定义自定义Python生成器函数,动态生成数据。
示例代码
def custom_generator():
"""自定义生成器函数,模拟数据流"""
for i in range(5):
yield np.random.rand(2), i # 返回特征和标签
# 使用from_generator创建数据集
dataset = tf.data.Dataset.from_generator(
custom_generator,
output_signature=(
tf.TensorSpec(shape=(2,), dtype=tf.float32), # 特征签名
tf.TensorSpec(shape=(), dtype=tf.int32) # 标签签名
)
)
# 查看数据集
for feature, label in dataset:
print(f"特征: {feature.numpy()}, 标签: {label.numpy()}")
说明:from_generator需要指定生成器函数和输出签名(signature),以定义数据形状和类型。这适配了各种自定义数据场景。
4. 数据集的基本结构
数据集通常组织为特征(features)和标签(labels)的组合,用于监督学习。TensorFlow支持多种结构。
特征和标签
- 简单结构:数据集中每个元素是一个元组
(feature, label),如前面的示例。 - 字典结构:也可以使用字典,例如
{'input': feature, 'label': label},这在多输入场景中常见。
多输入/多输出数据集
当模型有多个输入(如文本和图像)或多个输出时,数据集可以适配:
# 假设多输入数据:特征1和特征2
features1 = np.array([[1, 2], [3, 4]])
features2 = np.array([[5], [6]])
labels = np.array([0, 1])
# 创建数据集,元素为元组((特征1, 特征2), 标签)
dataset = tf.data.Dataset.from_tensor_slices(((features1, features2), labels))
for (feat1, feat2), label in dataset:
print(f"特征1: {feat1.numpy()}, 特征2: {feat2.numpy()}, 标签: {label.numpy()}")
提示:数据集结构应与模型输入输出匹配,确保数据流顺畅。
总结
本章介绍了TensorFlow中创建数据集的三种主要方法:从张量/数组、文件和生成器,以及数据集的基本结构。选择合适的方法取决于数据来源和规模,后续章节将深入文件处理等进阶主题。实践这些示例,你将能轻松构建自己的数据集流程。
下一步:探索数据集操作,如洗牌(shuffle)、批处理(batch)和转换,以优化训练性能。