Training a variational autoencoder in Keras raises "InvalidArgumentError: Incompatible shapes"

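I have been trying to get this VAE to work, but I kept hitting the same problem all evening and I am not sure where it comes from. I have tried removing the callbacks and the validation data, changing the loss function, and changing the sampling method. The traceback points at early stopping, but the error is always reported on whatever the last argument to fit happens to be. I have run out of ideas for fixing it.

Below is reproducible code, followed by the error that keeps recurring. Note that changing the batch size does change the error, but the mismatched number also decreases as the batch size decreases.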

    import pandas as pd
    from sklearn.datasets import make_blobs
    from sklearn.preprocessing import MinMaxScaler
    import keras.backend as K
    import tensorflow as tf
    from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
    from keras.models import Model, Sequential
    from keras.callbacks import EarlyStopping, LearningRateScheduler
    from keras.objectives import binary_crossentropy

    x, labels = make_blobs(n_samples=150000, n_features=110,  centers=16, cluster_std=4.0)
    scaler = MinMaxScaler()
    x = scaler.fit_transform(x)
    x = pd.DataFrame(x)
    train = x.sample(n = 100000)
    train_indexs = train.index.values
    test = x[~x.index.isin(train_indexs)]
    print(train.shape, test.shape)

    min_dim = 2
    batch_size = 1024

    def sampling(args):
        mu, log_sigma = args
        eps = K.random_normal(shape=(batch_size, min_dim), mean = 0.0, stddev = 1.0)
        return mu + K.exp(0.5 * log_sigma) * eps

    #Encoder
    inputs = Input(shape=(x.shape[1],))
    down1 = Dense(64, activation='relu')(inputs)
    mu = Dense(min_dim, activation='linear')(down1)
    log_sigma = Dense(min_dim, activation='linear')(down1)

    #Sampling
    sample_set = Lambda(sampling, output_shape=(min_dim,))([mu, log_sigma])

    #decoder
    up1 = Dense(64, activation='relu')(sample_set)
    output = Dense(x.shape[1], activation='sigmoid')(up1)

    vae = Model(inputs, output)
    encoder = Model(inputs, mu)

    def vae_loss(y_true, y_pred):
        recon  = binary_crossentropy(y_true, y_pred)
        kl = - 0.5 * K.mean(1 + log_sigma - K.square(mu) - K.exp(log_sigma), axis=-1)
        return recon + kl

    vae.compile(optimizer='adam', loss=vae_loss)
    vae.fit(train, train, shuffle = True, epochs = 1000,
            batch_size = batch_size, validation_data = (test, test),
            callbacks = [EarlyStopping(patience=50)])

Error:

      File "<ipython-input-2-7aa4be21434d>", line 62, in <module>
        callbacks = [EarlyStopping(patience=50)])
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\keras\engine\training.py", line 1239, in fit
        validation_freq=validation_freq)
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\keras\engine\training_arrays.py", line 196, in fit_loop
        outs = fit_function(ins_batch)
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\keras\backend.py", line 3792, in __call__
        outputs = self._graph_fn(*converted_inputs)
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1605, in __call__
        return self._call_impl(args, kwargs)
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1645, in _call_impl
        return self._call_flat(args, self.captured_inputs, cancellation_manager)
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1746, in _call_flat
        ctx, args, cancellation_manager=cancellation_manager))
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 598, in call
        ctx=ctx)
      File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
        inputs, attrs, num_outputs)

    InvalidArgumentError:  Incompatible shapes: [672] vs. [1024]
         [[node gradients/loss/dense_5_loss/vae_loss/weighted_loss/mul_grad/Mul_1 (defined at C:\Users\se01040434\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_1515]

    Function call stack:
    keras_scratch_graph

Answer:

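You create a random tensor with batch_size samples, where batch_size is a fixed value set in your code. Note, however, that the model does not necessarily receive batch_size input samples (for example, the last batch of the training or test data may contain fewer samples). In such cases, if your model implementation depends on the batch size, you should obtain its dynamic value with the keras.backend.shape function: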

    def sampling(args):
        # ...
        eps = K.random_normal(shape=(K.shape(mu)[0], min_dim))
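The numbers in the error message, [672] vs. [1024], are consistent with this explanation. As a rough check, assuming fit splits the 100000 training rows into consecutive batches of 1024 and lets the final batch be smaller, the leftover batch has exactly 672 rows. The helper name batch_sizes below is hypothetical and only illustrates the arithmetic; it is not part of Keras.

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of consecutive batches when n_samples rows are split into
    chunks of batch_size and the final chunk keeps whatever is left over.
    (Hypothetical helper for illustration; not a Keras function.)"""
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)
    return sizes


# Training set from the question: 100000 rows, batch_size = 1024
train_batches = batch_sizes(100000, 1024)
print(len(train_batches), train_batches[-1])  # 98 672

# Validation set from the question: 150000 - 100000 = 50000 rows
test_batches = batch_sizes(50000, 1024)
print(test_batches[-1])  # 848
```

Under the same assumption, the validation data would also end on a short batch (848 rows), so a noise tensor hard-coded to 1024 rows would not match there either. Reading the batch dimension from mu at run time avoids both cases.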
