TensorFlow 中文手册

10.2 数据集的创建

TensorFlow数据集创建完整教程 - 从张量到文件与生成器

TensorFlow 中文手册

本文作为TensorFlow中文学习手册的一部分,详细介绍如何创建数据集,包括从张量或数组使用from_tensor_slices方法、从CSV/TFRecord/图像文件(后续详解)、使用from_generator适配自定义数据,以及数据集的基本结构如特征+标签和多输入/多输出数据集,适合新手快速上手。

推荐工具
PyCharm专业版开发必备

功能强大的Python IDE,提供智能代码补全、代码分析、调试和测试工具,提高Python开发效率。特别适合处理列表等数据结构的开发工作。

了解更多

TensorFlow数据集创建

引言

在TensorFlow中,数据集(Dataset)是处理大量数据的核心工具,它提供了高效的数据加载、预处理和流水线操作,能够帮助模型训练更快速和可扩展。本章节将介绍几种常见的数据集创建方法,以及数据集的基本结构,让你轻松入门。

1. 从张量或数组创建数据集

TensorFlow提供了tf.data.Dataset.from_tensor_slices方法来直接从Python张量或NumPy数组创建数据集。这对于小型或内存中已有的数据非常方便。

示例代码

import tensorflow as tf
import numpy as np

# 创建样本数据
features = np.array([[1, 2], [3, 4], [5, 6]])  # 特征数据,形状(3, 2)
labels = np.array([0, 1, 0])  # 标签数据,形状(3,)

# 使用from_tensor_slices创建数据集
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# 查看数据集元素
for feature, label in dataset.take(3):
    print(f"特征: {feature.numpy()}, 标签: {label.numpy()}")

说明from_tensor_slices将数据切片成多个元素,每个元素对应一个样本。在示例中,数据集的每个元素包含一个特征和对应的标签。

2. 从文件创建数据集

对于大规模数据,通常从文件(如CSV、TFRecord或图像文件)中读取,以减少内存占用。TensorFlow支持多种文件格式,这里简要介绍,后续章节会详细展开。

  • CSV文件:使用tf.data.experimental.CsvDataset或相关API,适合表格数据。
  • TFRecord文件:TensorFlow的高效二进制格式,适用于大数据集,提供tf.data.TFRecordDataset
  • 图像文件:通过tf.data.Dataset.list_filestf.io.decode_image等处理图像数据。

提示

我们将在后续章节深入探讨如何从文件创建数据集,包括具体代码示例和最佳实践。

3. 生成器创建数据集

当数据生成逻辑复杂或数据源不支持直接加载时,可以使用tf.data.Dataset.from_generator。这允许你定义自定义Python生成器函数,动态生成数据。

示例代码

def custom_generator():
    """自定义生成器函数,模拟数据流"""
    for i in range(5):
        yield np.random.rand(2), i  # 返回特征和标签

# 使用from_generator创建数据集
dataset = tf.data.Dataset.from_generator(
    custom_generator,
    output_signature=(
        tf.TensorSpec(shape=(2,), dtype=tf.float32),  # 特征签名
        tf.TensorSpec(shape=(), dtype=tf.int32)      # 标签签名
    )
)

# 查看数据集
for feature, label in dataset:
    print(f"特征: {feature.numpy()}, 标签: {label.numpy()}")

说明from_generator需要指定生成器函数和输出签名(signature),以定义数据形状和类型。这适配了各种自定义数据场景。

4. 数据集的基本结构

数据集通常组织为特征(features)和标签(labels)的组合,用于监督学习。TensorFlow支持多种结构。

特征和标签

  • 简单结构:数据集中每个元素是一个元组(feature, label),如前面的示例。
  • 字典结构:也可以使用字典,例如{'input': feature, 'label': label},这在多输入场景中常见。

多输入/多输出数据集

当模型有多个输入(如文本和图像)或多个输出时,数据集可以适配:

# 假设多输入数据:特征1和特征2
features1 = np.array([[1, 2], [3, 4]])
features2 = np.array([[5], [6]])
labels = np.array([0, 1])

# 创建数据集,元素为元组((特征1, 特征2), 标签)
dataset = tf.data.Dataset.from_tensor_slices(((features1, features2), labels))

for (feat1, feat2), label in dataset:
    print(f"特征1: {feat1.numpy()}, 特征2: {feat2.numpy()}, 标签: {label.numpy()}")

提示:数据集结构应与模型输入输出匹配,确保数据流顺畅。

总结

本章介绍了TensorFlow中创建数据集的三种主要方法:从张量/数组、文件和生成器,以及数据集的基本结构。选择合适的方法取决于数据来源和规模,后续章节将深入文件处理等进阶主题。实践这些示例,你将能轻松构建自己的数据集流程。


下一步:探索数据集操作,如洗牌(shuffle)、批处理(batch)和转换,以优化训练性能。

开发工具推荐
Python开发者工具包

包含虚拟环境管理、代码格式化、依赖管理、测试框架等Python开发全流程工具,提高开发效率。特别适合处理复杂数据结构和算法。

获取工具包