TensorFlow 在拟合 TensorForestEstimator 时崩溃

我试图用表示 7 个特征和 7 个标签的数值浮点数据来拟合 TensorForestEstimator 模型。也就是说,featureslabels 的形状都是 (484876, 7)。我在 ForestHParams 中适当地设置了 num_classes=7num_features=7。数据的格式如下:

f1       f2     f3    f4      f5    f6    f7   l1       l2       l3       l4       l5       l6       l7
39000.0  120.0  65.0  1000.0  25.0  0.69  3.94 39000.0  39959.0  42099.0  46153.0  49969.0  54127.0  55911.0
32000.0  185.0  65.0  1000.0  75.0  0.46  2.19 32000.0  37813.0  43074.0  48528.0  54273.0  60885.0  63810.0 
30000.0  185.0  65.0  1000.0  25.0  0.41  1.80 30000.0  32481.0  35409.0  39145.0  42750.0  46678.0  48595.0

当调用 fit() 时,Python 崩溃并显示以下消息:

Python 在使用 _pywrap_tensorflow_internal.so 插件时意外退出。

启用 tf.logging.set_verbosity('INFO') 时的输出如下:

INFO:tensorflow:training graph for tree: 0
INFO:tensorflow:training graph for tree: 1
...
INFO:tensorflow:training graph for tree: 9998
INFO:tensorflow:training graph for tree: 9999
INFO:tensorflow:Create CheckpointSaverHook.
2017-07-26 10:25:30.908894: F tensorflow/contrib/tensor_forest/kernels/count_extremely_random_stats_op.cc:404] Check failed: column < num_classes_ (39001 vs. 8)
Process finished with exit code 134 (interrupted by signal 6: SIGABRT)

我不确定这个错误的含义,它似乎并不合理,因为 num_classes=7,而不是 8,并且由于特征和标签的形状是 (484876, 7),我不知道 39001 是从哪里来的。

以下是复现问题的代码:

import numpy as np
import pandas as pd
import os
def get_training_data():
    training_file = "data.txt"
    data = pd.read_csv(training_file, sep='\t')
    X = np.array(data.drop('Result', axis=1), dtype=np.float32)
    y = []
    for e in data.ResultStr:
        y.append(list(np.array(str(e).replace('[', '').replace(']', '').split(','))))
    y = np.array(y, dtype=np.float32)
    features = tf.constant(X)
    labels = tf.constant(y)
    return features, labels
hyperparameters = ForestHParams(
    num_trees=100,
    max_nodes=10000,
    bagging_fraction=1.0,
    num_splits_to_consider=0,
    feature_bagging_fraction=1.0,
    max_fertile_nodes=0,
    split_after_samples=250,
    min_split_samples=5,
    valid_leaf_threshold=1,
    dominate_method='bootstrap',
    dominate_fraction=0.99,
    # 以上所有参数均为默认值
    num_classes=7,
    num_features=7)
estimator = TensorForestEstimator(
    params=hyperparameters,
    # 以下所有参数均为默认值
    device_assigner=None,
    model_dir=None,
    graph_builder_class=RandomForestGraphs,
    config=None,
    weights_name=None,
    keys_name=None,
    feature_engineering_fn=None,
    early_stopping_rounds=100,
    num_trainers=1,
    trainer_id=0,
    report_feature_importances=False,
    local_eval=False)
estimator.fit(
    input_fn=lambda: get_training_data(),
    max_steps=100,
    monitors=[
        TensorForestLossHook(
            early_stopping_rounds=30
        )
    ])

如果我用 SKCompat 包装它,仍然无法工作,会出现相同的错误。导致此崩溃的原因是什么?


回答:

需要在 ForestHParams 中指定 regression=True,因为 TensorForestEstimator 默认假设它被用于解决分类问题,只能输出一个值。

在初始化估计器时会隐式创建一个 num_outputs 变量,如果未指定 regression,它将被设置为 1。如果指定了 regression,那么 num_outputs = num_classes,并正常保存检查点。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注