展平和重塑输入到编码器

我有一个包含3D固定长度段的数据集,形状为(1,200,4),我想将其输入到一个具有全连接层的自编码器中,类似于以下结构:

 encoderautoencoder.add(Dense(200, input_shape=(self.input_dim,), activation='relu'))  autoencoder.add(Dense(100, input_shape=(self.input_dim,), activation='relu'))  autoencoder.add(Dense(80, input_shape=(self.input_dim,), activation='linear'))  autoencoder.add(Dense(encoding_dim, input_shape=(self.input_dim,), activation='linear'))         #decoderautoencoder.add(Dense(80, activation='linear')) autoencoder.add(Dense(100, activation='relu')) autoencoder.add(Dense(200, activation='relu'))  autoencoder.add(Dense(self.input_dim, activation='linear'))

我知道在输入到这个网络之前需要展平输入,我应该在编码器的第一个Dense()层之前添加一个Flatten()层。

我不确定的是,如何在解码器部分将输入重构回其原始形状(1,200,4)

如果这有帮助的话,输入是由固定长度段组成的numpy-ndarray,例如:

print(train_x.shape)(140508, 1, 200, 4)

例如,假设输入是这样的:

X = np.random.randn(3400, 1, 200, 4)X.shape(3400, 1, 200, 4)

我如何在编码器中对这个输入应用展平层,然后在解码器中重构它,以实现类似于上述的全连接自编码器?

对此有何建议?


回答:

这里有一个可能的解决方案… 你只需在最后注意维度,并进行适当的重塑操作

encoding_dim = 20input_shape = (1,200,4)n_sample = 100X = np.random.uniform(0,1, (n_sample,) + input_shape)autoencoder = Sequential()autoencoder.add(Flatten(input_shape=input_shape))autoencoder.add(Dense(200, activation='relu'))  autoencoder.add(Dense(100, activation='relu'))  autoencoder.add(Dense(80, activation='relu'))  autoencoder.add(Dense(encoding_dim, activation='relu'))         #decoderautoencoder.add(Dense(80, activation='relu')) autoencoder.add(Dense(100, activation='relu')) autoencoder.add(Dense(200, activation='relu'))  autoencoder.add(Dense(np.prod(input_shape), activation='linear'))autoencoder.add(Reshape(input_shape))print(autoencoder.summary())autoencoder.compile('adam', 'mse')autoencoder.fit(X,X, epochs=3)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注