如何在CNN Keras中删除特定的滤波器

假设我已经按照以下方式创建了一个模型

from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Dense, Flatten, Conv2Dmodel = Sequential()model.add(Conv2D(32, kernel_size=(3, 3),                     activation='relu',                     input_shape=input_shape,                     kernel_initializer='he_normal',))model.add(Flatten())model.add(Dense(10, activation='softmax'))

在完成模型训练后,我如何删除索引为1、5和9的滤波器?这样剩余的滤波器总数将是29个,但不包括原来位于1、5和9位置的滤波器。

我想使用这个“修改后的模型”再次对测试数据进行预测,但不使用那些滤波器

score = modified_model.evaluate(x, y)

回答:

这是将训练好的网络中所需的滤波器替换为零的解决方案

inp = Input((10,10,3))c = Conv2D(32, kernel_size=(3, 3),           activation='relu',           kernel_initializer='he_normal')f = Flatten()d = Dense(10, activation='softmax')x = c(inp)x = f(x)out = d(x)model = Model(inp, out)print(model.summary())model.fit(.....)w,b = c.get_weights()w[:,:,:,1] = 0w[:,:,:,5] = 0w[:,:,:,9] = 0c.set_weights([w,b])

关于修改训练好的网络,删除权重是不可能的。在你的情况下,这与下面的层(Flatten和Dense)不兼容

w,b = c.get_weights()w = np.delete(w, [1,5,9], -1)b = np.delete(b, [1,5,9], 0)new_c = Conv2D(29, kernel_size=(3, 3),               activation='relu',               kernel_initializer='he_normal',               trainable=False)x = new_c(inp)x = f(x)out = d(x) # -----> 错误!new_model= Model(inp, out)new_c.set_weights([w,b])print(new_model.summary())

你可以创建一个新的网络,在这个网络中管理旧的Conv2D滤波器,但你需要重新训练下面的层

w,b = c.get_weights()w = np.delete(w, [1,5,9], -1)b = np.delete(b, [1,5,9], 0)new_inp = Input((10,10,3))new_c = Conv2D(29, kernel_size=(3, 3),           activation='relu',           kernel_initializer='he_normal',           trainable=False)new_f = Flatten()new_d = Dense(10, activation='softmax')new_x = new_c(new_inp)new_x = new_f(new_x)new_out = new_d(new_x)new_model = Model(new_inp, new_out)new_c.set_weights([w,b])print(new_model.summary())new_model.fit(.....)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注