在Keras中使用fit_generator

我在使用Keras中的fit_generator时还是新手,所以我尝试编写一个简单的脚本来帮助我理解它的工作原理。

X = np.array([[1,2],[10,3],[2,4],[20,5],[30,1],[3,5],[4,6],[7,4],[5,10],[1,7]])Y = np.array([[2,3],[30,13],[8,6],[100,25],[30,31],[15,8],[24,10],[28,11],[50,15],[7,8]])def generator(feat,labels):    i=0    while (True):        yield feat[i],labels[i]        i+=1model_fnn = tf.keras.models.Sequential()model_fnn.add(tf.keras.layers.Dense(50, input_dim=X.shape[1], activation=tf.nn.relu))model_fnn.add(tf.keras.layers.Dense(Y.shape[1], activation=tf.keras.activations.linear))nb_epoch = 3000model_fnn.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])model_fnn.fit_generator(generator(X,Y), steps_per_epoch=10, epochs=nb_epoch, verbose=0)

但它给我报了错误:

ValueError: Error when checking input: expected dense_2_input to have shape (2,) but got array with shape (1,)

谁能帮帮我?谢谢!


回答:

这个错误似乎是因为你需要将每个输入/标签包装成一个Numpy数组,这将是训练批次。这个生成器的代码应该是这样的:

def generator(feat,labels):    i=0    while (True):        yield np.array([feat[i]]), np.array([labels[i]])        i+=1

这样你的错误应该会解决。你将使用批量大小为1进行训练,因为传递给训练的每个数组只包含一个对象。

但是,为了使用带有yield的生成器进行训练,你需要确保循环不会崩溃,因为它是一个无限循环,而你没有无限的数据。这可以通过使用itertools.cycle来实现:

import itertoolsdef generator(feat,labels):    pairs = [(x, y) for x in feat for y in labels]    cycle_pairs = itertools.cycle(pairs)    while (True):        f, p = next(cycle_pairs)        return np.array([f]), np.array([p])

此外,为了进一步使用yield生成器,你可以给函数添加一个参数来指定批量大小(正如PyGirl的评论中所述)。这看起来应该像这样:

def generator(feat, labels, batch_size):    pairs = [(x, y) for x in feat for y in labels]    cycle_pairs = itertools.cycle(pairs)    while (True):        x = []        y = []        for _ in range(batch_size):            f, p = next(cycle_pairs)            x.append(f)            y.append(p)        yield np.array(x), np.array(y)

这意味着,在每一步中,生成器将yield一个包含batch_size个元素的Numpy数组以供训练。

为了更好地理解如何在Keras中使用fit_generator进行训练,我还建议阅读一些关于Keras Sequence Util的内容,使用它比生成器更安全。

希望这对你有帮助!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注