注意:此问题附带了一个记录详尽的Colab笔记本。
TensorFlow的文档有时让人觉得不够详尽。一些较旧的低级API文档似乎已被删除,大多数新文档都指向使用更高级的API,例如TensorFlow的keras
子集或estimators
。如果高级API不经常依赖于其低级别,这不会构成太大问题。以estimators
为例(特别是使用TensorFlow Records时的input_fn
)。
在以下Stack Overflow帖子中:
- TensorFlow v1.10:以字节字符串还是按通道存储图像?
- TensorFlow 1.10 TFRecordDataset – 恢复TFRecords
- TensorFlow v1.10+ 为什么在创建检查点时需要输入服务接收函数?
- TensorFlow 1.10+ 自定义估计器使用train_and_evaluate进行早期停止
- TensorFlow自定义估计器在训练后调用evaluate时卡住
在TensorFlow / StackOverflow社区的慷慨帮助下,我们已经更接近于实现TensorFlow“创建自定义估计器”指南未能实现的目标,展示如何制作一个在实践中可能真正使用的估计器(而不是玩具示例),例如一个能够:
- 在性能恶化时使用验证集进行早期停止,
- 从TF Records读取数据,因为许多数据集比TensorFlow推荐的1Gb内存更大,并且
- 在训练时保存其最佳版本
尽管我对这方面还有很多疑问(从将数据编码到TF Record的最佳方式,到serving_input_fn
到底期望什么),但有一个问题比其他问题更加突出:
如何使用我们刚刚创建的自定义估计器进行预测?
在predict的文档中,它指出:
input_fn
:一个构造特征的函数。预测将持续进行,直到input_fn
引发输入结束异常(tf.errors.OutOfRangeError
或StopIteration
)。有关更多信息,请参阅预制估计器。该函数应构造并返回以下之一:
- 一个
tf.data.Dataset
对象:Dataset对象的输出必须具有与下文相同的约束。- 特征:一个
tf.Tensor
或一个字符串特征名称到Tensor
的字典。特征由model_fn
消费。它们应该满足model_fn
对输入的期望。- 一个元组,在这种情况下,第一个项目被提取为特征。
(可能)最有可能的是,如果一个人使用estimator.predict
,他们正在使用内存中的数据,例如一个密集张量(因为保留的测试集可能会通过evaluate
)。
所以我在附件的Colab中创建了一个单一的密集示例,将其包装在tf.data.Dataset
中,并调用predict
来获取ValueError
。
如果有人能向我解释如何:
- 加载我的保存的估计器
- 给定一个内存中的密集示例,使用估计器预测输出
回答:
to_predict = random_onehot((1, SEQUENCE_LENGTH, SEQUENCE_CHANNELS))\ .astype(tf_type_string(I_DTYPE))pred_features = {'input_tensors': to_predict}pred_ds = tf.data.Dataset.from_tensor_slices(pred_features)predicted = est.predict(lambda: pred_ds, yield_single_examples=True)next(predicted)
ValueError: Tensor(“IteratorV2:0”, shape=(), dtype=resource) must be from the same graph as Tensor(“TensorSliceDataset:0”, shape=(), dtype=variant).
当您使用tf.data.Dataset
模块时,它实际上定义了一个与模型图独立的输入图。这里发生的情况是,您首先通过调用tf.data.Dataset.from_tensor_slices()
创建了一个小图,然后估计器API通过自动调用dataset.make_one_shot_iterator()
创建了第二个图。这两个图无法通信,因此抛出错误。
为了避免这种情况,您永远不应该在estimator.train/evaluate/predict之外创建数据集。这就是为什么所有与数据相关的内容都被包装在输入函数中。
def predict_input_fn(data, batch_size=1): dataset = tf.data.Dataset.from_tensor_slices(data) return dataset.batch(batch_size).prefetch(None)predicted = est.predict(lambda: predict_input_fn(pred_features), yield_single_examples=True)next(predicted)
现在,图不会在predict调用之外创建。
我还添加了dataset.batch()
,因为您的其他代码期望批处理数据,并且它会抛出形状错误。预取只是为了加速处理。