我在尝试实现TensorFlow提供的高级API,特别是基础分类器。然而,在尝试训练模型时,我遇到了以下错误:
错误:
NotFoundError (see above for traceback): Key baseline/bias not found in checkpoint [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
代码:
import tensorflow as tfimport numpy as npfrom sklearn import datasetsfrom sklearn.model_selection import train_test_splitdef digit_cross(): # Number of classes, one class for each of 10 digits. num_classes = 10 digit = datasets.load_digits() x = digit.data y = digit.target x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42) y_train_index = np.arange(y_train.size) train_input_fn = tf.estimator.inputs.numpy_input_fn( x={"x": np.array(x_train)}, y=np.array(y_train), num_epochs=None, shuffle=False) # Build BaselineClassifier classifier = tf.estimator.BaselineClassifier(n_classes=num_classes, model_dir="./checkpoints_tutorial17-1/") # Fit model. classifier.train(train_input_fn)digit_cross()
回答:
看起来你在 model_dir="./checkpoints_tutorial17-1/"
中有一个来自其他模型的检查点,而不是来自基础分类器。具体来说,你在这个文件夹中有一个检查点文件和model.ckpt-*文件。
正如TensorFlow文档中所述:
- model_dir: 用于保存模型参数、图形等的目录。这也可以用来从目录中加载检查点到估算器中,以继续训练之前保存的模型。 如果是PathLike对象,路径将被解析。如果为None,将使用配置中的model_dir(如果已设置)。如果两者都设置,它们必须相同。如果两者都为None,将使用临时目录。
在这里,BaselineClassifier
首先会构建一个使用 baseline/bias
的图。然后它发现 model_dir
中有一个之前的检查点。它会尝试加载这个检查点,如果你已经执行了 tf.logging.set_verbosity(tf.logging.INFO)
,你应该会看到类似于以下信息的提示:
"INFO:tensorflow:Restoring parameters from .../checkpoints_tutorial17-1\model.ckpt-..."
因为 model_dir
中的这个检查点不是来自 BaselineClassifier
,所以它不会有 baseline/bias
。BaselineClassifier
找不到它,因此会抛出错误。