背景
我们的数据存储在.tfrecord
文件中,X
是我们的训练数据,包含40x40
的灰度图像,而Y
是标签。这些图像按顺序排列(顺序很重要)。我们希望使用TensorFlow的Estimator API输入这些图像,以训练神经网络模型(例如:LSTM),并使用GoogleML处理各种时间窗口大小和偏移量。
问题
如何将输入的特征字符串重塑为特定长度的序列,例如,将1000
张图像放入一个序列中,然后对这些序列进行窗口化处理,例如,获取50
张图像的窗口,窗口偏移为25
?
当前状态
我们已经设法实现了这一点(下面的稀疏示例),但没有首先将数据重塑为长度为1000
的集合,结果是窗口从一个集合的第975个元素跨越到下一个集合的第25个元素,这是我们不希望看到的。我们需要重叠的窗口,这些窗口从每组1000
张图像的开始到结束,但不能跨越它们的边界。
import tensorflow as tf# .tfrecord文件包含数据'X'和标签'Y'dataset = tf.data.TFRecordDataset('.tfrecord文件')# 为dataset.map函数定义解析函数def _parse_function(proto): # 定义解析常量 image_size = 40 num_channels = 1 num_classes = 3 # 定义您的tfrecord特征键并 # 将一维数组重塑为二维数组(图像) keys_to_features = {'X': tf.FixedLenFeature([image_size, image_size, num_channels], tf.float32), # 图像高度,图像宽度,通道数 'Y': tf.FixedLenFeature([], tf.int64)} # 加载一个示例 parsed_features = tf.parse_single_example(proto, keys_to_features) # 提取图像和标签 image = parsed_features['X'] labels = tf.cast( parsed_features['Y'], tf.int32 ) labels = tf.one_hot( labels, depth=num_classes ) # 独热编码 return image, labels# 将数据重塑为解析格式dataset = dataset.map(_parse_function)# 定义数据集参数window_size = 50batch_size = 500window_shift = int( window_size / 2 ) # 25# 实现滑动窗口 dataset = dataset.window(size=window_size, shift=window_shift, drop_remainder=True ).flat_map( lambda x: x.batch(window_size) )# 批处理数据dataset = dataset.batch(batch_size)# 创建迭代器# iterator = dataset.make_one_shot_iterator().get_next()
上面的iterator
将为X
数据返回一个形状为(batch_size, window_size, image_height, image_width, number of channels)的张量,在我们的案例中为(500, 50, 40, 40, 1)
,而Y
为(500, 3)
的数组。
回答:
我通过过滤掉跨越边界的窗口来实现这一点。一旦你有了解析的特征,就对所有内容应用窗口化,然后计算哪些窗口是溢出的并过滤掉它们:
ds = tf.data.TFRecordDataset( filename )ds = ds.map( _parse_function )# 应用窗口化ds = ds.window( size=50, shift=25, drop_remainder=True ).flat_map( lambda x, y: tf.data.Dataset.zip( (x.batch(50), y.batch(50)) ) )# 枚举数据集并过滤每40个窗口ds = ds.apply( tf.data.experimental.enumerate_dataset(start=1) ).filter( lambda i, x: tf.not_equal( i % 40, 0) )# 去除枚举ds = ds.map( lambda i, x: x )# 批处理、洗牌等......
澄清:每40个窗口都被过滤掉,因为如果你有长度为1000的集合和窗口偏移为25,那么将会有set_len / win_shift = 40
个窗口,最后一个(即第40个)将溢出到下一组。请注意,枚举从1开始,因此第0个样本不会被取出,因为0 % x == 0
。
请注意,这更像是一种权宜之计,而非真正的解决方案。它在50%的重叠下工作得很好,但在其他百分比下,计算要丢弃的索引会变得更加复杂(在>50%重叠的情况下,会有多于一个窗口溢出到下一组,因此需要多个过滤器)。