TensorFlow自定义估计器预测时抛出值错误

注意:此问题附带了一个记录详尽的Colab笔记本。

TensorFlow的文档有时让人觉得不够详尽。一些较旧的低级API文档似乎已被删除,大多数新文档都指向使用更高级的API,例如TensorFlow的keras子集或estimators。如果高级API不经常依赖于其低级别,这不会构成太大问题。以estimators为例(特别是使用TensorFlow Records时的input_fn)。

在以下Stack Overflow帖子中:

在TensorFlow / StackOverflow社区的慷慨帮助下,我们已经更接近于实现TensorFlow“创建自定义估计器”指南未能实现的目标,展示如何制作一个在实践中可能真正使用的估计器(而不是玩具示例),例如一个能够:

  • 在性能恶化时使用验证集进行早期停止,
  • 从TF Records读取数据,因为许多数据集比TensorFlow推荐的1Gb内存更大,并且
  • 在训练时保存其最佳版本

尽管我对这方面还有很多疑问(从将数据编码到TF Record的最佳方式,到serving_input_fn到底期望什么),但有一个问题比其他问题更加突出:

如何使用我们刚刚创建的自定义估计器进行预测?

predict的文档中,它指出:

input_fn:一个构造特征的函数。预测将持续进行,直到input_fn引发输入结束异常(tf.errors.OutOfRangeErrorStopIteration)。有关更多信息,请参阅预制估计器。该函数应构造并返回以下之一:

  • 一个tf.data.Dataset对象:Dataset对象的输出必须具有与下文相同的约束。
  • 特征:一个tf.Tensor或一个字符串特征名称到Tensor的字典。特征由model_fn消费。它们应该满足model_fn对输入的期望。
  • 一个元组,在这种情况下,第一个项目被提取为特征。

(可能)最有可能的是,如果一个人使用estimator.predict,他们正在使用内存中的数据,例如一个密集张量(因为保留的测试集可能会通过evaluate)。

所以我在附件的Colab中创建了一个单一的密集示例,将其包装在tf.data.Dataset中,并调用predict来获取ValueError

如果有人能向我解释如何:

  1. 加载我的保存的估计器
  2. 给定一个内存中的密集示例,使用估计器预测输出

回答:

to_predict = random_onehot((1, SEQUENCE_LENGTH, SEQUENCE_CHANNELS))\        .astype(tf_type_string(I_DTYPE))pred_features = {'input_tensors': to_predict}pred_ds = tf.data.Dataset.from_tensor_slices(pred_features)predicted = est.predict(lambda: pred_ds, yield_single_examples=True)next(predicted)

ValueError: Tensor(“IteratorV2:0”, shape=(), dtype=resource) must be from the same graph as Tensor(“TensorSliceDataset:0”, shape=(), dtype=variant).

当您使用tf.data.Dataset模块时,它实际上定义了一个与模型图独立的输入图。这里发生的情况是,您首先通过调用tf.data.Dataset.from_tensor_slices()创建了一个小图,然后估计器API通过自动调用dataset.make_one_shot_iterator()创建了第二个图。这两个图无法通信,因此抛出错误。

为了避免这种情况,您永远不应该在estimator.train/evaluate/predict之外创建数据集。这就是为什么所有与数据相关的内容都被包装在输入函数中。

def predict_input_fn(data, batch_size=1):  dataset = tf.data.Dataset.from_tensor_slices(data)  return dataset.batch(batch_size).prefetch(None)predicted = est.predict(lambda: predict_input_fn(pred_features), yield_single_examples=True)next(predicted)

现在,图不会在predict调用之外创建。

我还添加了dataset.batch(),因为您的其他代码期望批处理数据,并且它会抛出形状错误。预取只是为了加速处理。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注