11.3 TFRecord 文件的写入与读取
TensorFlow TFRecord 文件操作教程:写入读取数据集压缩
本教程详细介绍TensorFlow中TFRecord文件的写入与读取方法,包括从原始数据生成单文件和多个文件、创建tf.data.Dataset数据集,以及通过压缩优化存储的完整指南,适合新人学习。
推荐工具
TensorFlow TFRecord 文件:写入、读取与优化存储
TFRecord 是 TensorFlow 中用于高效存储序列化数据的二进制格式,特别适合大规模机器学习任务,因为它能加速数据加载并减少存储空间。本教程将逐步介绍 TFRecord 文件的操作,从基础写入到高级优化。
1. TFRecord 文件的写入:从原始数据生成
写入 TFRecord 文件是将原始数据(如文本、图像、数值)转换为序列化格式的过程。TensorFlow 提供了 tf.train.Example 和 tf.train.Feature 来构建数据条目。
单文件写入
将数据写入单个 TFRecord 文件,适用于数据量较小的场景。
import tensorflow as tf
# 示例:写入一个包含整数和字符串的数据
def write_tfrecord_single(output_path, data_list):
with tf.io.TFRecordWriter(output_path) as writer:
for data in data_list:
# 创建一个 Example
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['label']])),
'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['text'].encode()]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
# 使用示例
data = [{'label': 1, 'text': 'hello'}, {'label': 2, 'text': 'world'}]
write_tfrecord_single('data.tfrecord', data)
多文件写入
对于大数据集,可以写入多个 TFRecord 文件以提高并行处理效率。
def write_tfrecord_multi(output_prefix, data_list, num_shards=2):
# 将数据分配到多个文件中
for shard_id in range(num_shards):
output_path = f'{output_prefix}_{shard_id}.tfrecord'
with tf.io.TFRecordWriter(output_path) as writer:
# 写入每个分片的数据
for i in range(shard_id, len(data_list), num_shards):
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data_list[i]['label']])),
'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data_list[i]['text'].encode()]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
# 使用示例
write_tfrecord_multi('data_sharded', data, num_shards=2)
2. 从 TFRecord 文件创建 tf.data.Dataset
读取 TFRecord 文件并使用 tf.data.TFRecordDataset 创建高效的数据集,便于训练模型。
def create_dataset(file_pattern, parse_function):
# 创建数据集
dataset = tf.data.TFRecordDataset(file_pattern)
dataset = dataset.map(parse_function) # 解析数据
dataset = dataset.shuffle(buffer_size=100).batch(32).prefetch(tf.data.AUTOTUNE)
return dataset
# 解析函数示例
def parse_example(serialized_example):
feature_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'text': tf.io.FixedLenFeature([], tf.string)
}
parsed_features = tf.io.parse_single_example(serialized_example, feature_description)
return parsed_features['label'], parsed_features['text']
# 使用示例:读取单文件
single_dataset = create_dataset('data.tfrecord', parse_example)
# 使用示例:读取多文件(使用通配符)
multi_dataset = create_dataset('data_sharded_*.tfrecord', parse_example)
3. TFRecord 文件的压缩与解压缩(优化存储)
TFRecord 文件支持 GZIP 压缩,以减少磁盘占用并可能加速 I/O 操作。
写入压缩文件
在写入时启用压缩。
def write_compressed_tfrecord(output_path, data_list):
with tf.io.TFRecordWriter(output_path, options=tf.io.TFRecordOptions(compression_type='GZIP')) as writer:
for data in data_list:
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['label']])),
'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['text'].encode()]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
# 使用示例
write_compressed_tfrecord('data_compressed.tfrecord.gz', data)
读取压缩文件
读取时指定相同的压缩类型。
def read_compressed_dataset(file_path):
dataset = tf.data.TFRecordDataset(file_path, compression_type='GZIP')
dataset = dataset.map(parse_example) # 使用相同的解析函数
return dataset
# 使用示例
compressed_dataset = read_compressed_dataset('data_compressed.tfrecord.gz')
总结与最佳实践
- 写入: 对于小数据使用单文件,大数据使用多文件分片以提高效率。
- 读取: 使用
tf.data.TFRecordDataset和解析函数创建可迭代数据集,结合 shuffle 和 batch 优化性能。 - 压缩: 启用 GZIP 压缩可以显著减少存储空间,特别是在处理文本或重复数据时。
- 注意事项: 确保写入和读取时的特征定义一致;压缩可能会轻微增加 CPU 开销,但通常利大于弊。
通过本教程,您应该能够掌握 TFRecord 文件的基本操作,从而在 TensorFlow 项目中高效管理数据。
开发工具推荐