TensorFlow 中文手册

11.3 TFRecord 文件的写入与读取

TensorFlow TFRecord 文件操作教程:写入读取数据集压缩

TensorFlow 中文手册

本教程详细介绍TensorFlow中TFRecord文件的写入与读取方法,包括从原始数据生成单文件和多个文件、创建tf.data.Dataset数据集,以及通过压缩优化存储的完整指南,适合新人学习。

推荐工具
PyCharm专业版开发必备

功能强大的Python IDE,提供智能代码补全、代码分析、调试和测试工具,提高Python开发效率。特别适合处理列表等数据结构的开发工作。

了解更多

TensorFlow TFRecord 文件:写入、读取与优化存储

TFRecord 是 TensorFlow 中用于高效存储序列化数据的二进制格式,特别适合大规模机器学习任务,因为它能加速数据加载并减少存储空间。本教程将逐步介绍 TFRecord 文件的操作,从基础写入到高级优化。

1. TFRecord 文件的写入:从原始数据生成

写入 TFRecord 文件是将原始数据(如文本、图像、数值)转换为序列化格式的过程。TensorFlow 提供了 tf.train.Exampletf.train.Feature 来构建数据条目。

单文件写入

将数据写入单个 TFRecord 文件,适用于数据量较小的场景。

import tensorflow as tf

# 示例:写入一个包含整数和字符串的数据
def write_tfrecord_single(output_path, data_list):
    with tf.io.TFRecordWriter(output_path) as writer:
        for data in data_list:
            # 创建一个 Example
            feature = {
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['label']])),
                'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['text'].encode()]))
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

# 使用示例
data = [{'label': 1, 'text': 'hello'}, {'label': 2, 'text': 'world'}]
write_tfrecord_single('data.tfrecord', data)

多文件写入

对于大数据集,可以写入多个 TFRecord 文件以提高并行处理效率。

def write_tfrecord_multi(output_prefix, data_list, num_shards=2):
    # 将数据分配到多个文件中
    for shard_id in range(num_shards):
        output_path = f'{output_prefix}_{shard_id}.tfrecord'
        with tf.io.TFRecordWriter(output_path) as writer:
            # 写入每个分片的数据
            for i in range(shard_id, len(data_list), num_shards):
                feature = {
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data_list[i]['label']])),
                    'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data_list[i]['text'].encode()]))
                }
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())

# 使用示例
write_tfrecord_multi('data_sharded', data, num_shards=2)

2. 从 TFRecord 文件创建 tf.data.Dataset

读取 TFRecord 文件并使用 tf.data.TFRecordDataset 创建高效的数据集,便于训练模型。

def create_dataset(file_pattern, parse_function):
    # 创建数据集
    dataset = tf.data.TFRecordDataset(file_pattern)
    dataset = dataset.map(parse_function)  # 解析数据
    dataset = dataset.shuffle(buffer_size=100).batch(32).prefetch(tf.data.AUTOTUNE)
    return dataset

# 解析函数示例
def parse_example(serialized_example):
    feature_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'text': tf.io.FixedLenFeature([], tf.string)
    }
    parsed_features = tf.io.parse_single_example(serialized_example, feature_description)
    return parsed_features['label'], parsed_features['text']

# 使用示例:读取单文件
single_dataset = create_dataset('data.tfrecord', parse_example)
# 使用示例:读取多文件(使用通配符)
multi_dataset = create_dataset('data_sharded_*.tfrecord', parse_example)

3. TFRecord 文件的压缩与解压缩(优化存储)

TFRecord 文件支持 GZIP 压缩,以减少磁盘占用并可能加速 I/O 操作。

写入压缩文件

在写入时启用压缩。

def write_compressed_tfrecord(output_path, data_list):
    with tf.io.TFRecordWriter(output_path, options=tf.io.TFRecordOptions(compression_type='GZIP')) as writer:
        for data in data_list:
            feature = {
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['label']])),
                'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['text'].encode()]))
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

# 使用示例
write_compressed_tfrecord('data_compressed.tfrecord.gz', data)

读取压缩文件

读取时指定相同的压缩类型。

def read_compressed_dataset(file_path):
    dataset = tf.data.TFRecordDataset(file_path, compression_type='GZIP')
    dataset = dataset.map(parse_example)  # 使用相同的解析函数
    return dataset

# 使用示例
compressed_dataset = read_compressed_dataset('data_compressed.tfrecord.gz')

总结与最佳实践

  • 写入: 对于小数据使用单文件,大数据使用多文件分片以提高效率。
  • 读取: 使用 tf.data.TFRecordDataset 和解析函数创建可迭代数据集,结合 shuffle 和 batch 优化性能。
  • 压缩: 启用 GZIP 压缩可以显著减少存储空间,特别是在处理文本或重复数据时。
  • 注意事项: 确保写入和读取时的特征定义一致;压缩可能会轻微增加 CPU 开销,但通常利大于弊。

通过本教程,您应该能够掌握 TFRecord 文件的基本操作,从而在 TensorFlow 项目中高效管理数据。

开发工具推荐
Python开发者工具包

包含虚拟环境管理、代码格式化、依赖管理、测试框架等Python开发全流程工具,提高开发效率。特别适合处理复杂数据结构和算法。

获取工具包