如何使用Keras的Conv2D层处理可变形状的输入

我有一个名为X_train的NP数组,具有以下属性:

X_train.shape = (139,)X_train[0].shape = (210, 224, 3)X_train[1].shape = (220,180, 3)

换句话说,有139个观测值。每张图像的宽度和高度不同,但它们都有3个通道。因此,维度应该是(139, None, None, 3),其中None表示可变的维度。

由于在层中不包括观测数量的维度,我在Conv2D层中使用了input_shape=(None,None,3)。但这会导致以下错误:

expected conv2d_1_input to have 4 dimensions, but got array with shape (139, 1)

我的猜测是问题出在输入形状是(139,)而不是(139, None, None, 3)。然而,我不确定如何转换成后者。


回答:

解决您问题的一种可能方法是用零填充数组,使它们都具有相似的尺寸。之后,您的输入形状将类似于(139, max_x_dimension, max_y_dimension, 3)

以下函数可以完成这项工作:

import numpy as npdef fillwithzeros(inputarray, outputshape):    """    Fills input array with dtype 'object' so that all arrays have the same shape as 'outputshape'    inputarray: input numpy array    outputshape: max dimensions in inputarray (obtained with the function 'findmaxshape')    output: inputarray filled with zeros    """    length = len(inputarray)    output = np.zeros((length,)+outputshape, dtype=np.uint8)    for i in range(length):        output[i][:inputarray[i].shape[0],:inputarray[i].shape[1],:] = inputarray[i]    return outputdef findmaxshape(inputarray):    """    Finds maximum x and y in an inputarray with dtype 'object' and 3 dimensions    inputarray: input numpy array    output: detected maximum shape    """    max_x, max_y, max_z = 0, 0, 0    for array in inputarray:        x, y, z = array.shape        if x > max_x:            max_x = x        if y > max_y:            max_y = y        if z > max_z:            max_z = z    return(max_x, max_y, max_z)#Create random data similar to your datarandom_data1 = np.random.randint(0,255, 210*224*3).reshape((210, 224, 3))random_data2 = np.random.randint(0,255, 220*180*3).reshape((220, 180, 3))X_train = np.array([random_data1, random_data2])#Convert X_train so that all images have the same shapenew_shape = findmaxshape(X_train)new_X_train = fillwithzeros(X_train, new_shape)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注