如何打印特定层张量的值

我想检查mask应用于张量后的值。

这是模型的部分截图。我设置temp = x,这样我稍后可以打印temp来检查具体的值。

假设我们有一个使用声学特征的4类分类模型,数据格式为(1000,50,136),分别代表(批次,时间步,特征)。

我们的目标是检查模型是否按时间步学习特征。换句话说,我们希望确认模型是通过如图中红色矩形所示的切片方式学习的。从逻辑上讲,这是Keras LSTM层的处理方式,但当参数改变时(例如,Dense层的单元数),生成的混淆矩阵差异很大。验证准确率保持在45%,因此我们希望可视化模型。

提议的想法是打印出第一批次的第一步,并打印模型中的输入。如果它们相同,那么模型是以正确的方式学习的(一次处理(136,1)个特征),而不是一次处理单个特征的(50,1)个时间步。

enter image description here

input_feature = Input(shape=(X_train.shape[1],X_train.shape[2]))x = Masking(mask_value=0)(input_feature)temp = xx = Dense(Dense_unit,kernel_regularizer=l2(dense_reg), activation='relu')(x)        

我尝试过tf.print(),但得到了AttributeError: 'Tensor' object has no attribute '_datatype_enum'的错误


如Lescurel在获取非最终Keras模型层的输出中建议的。

model2 = Model(inputs=[input_attention, input_feature], outputs=model.get_layer('masking')).outputprint(model2.predict(X_test))AttributeError: 'Masking' object has no attribute 'op'

回答:

您想在掩码后输出。Lescurel在评论中的链接展示了如何做到这一点。这个GitHub链接也是如此。

您需要创建一个新模型,

  • 输入来自您的模型的输入
  • 输出来自该层的输出

我用从您的代码片段中衍生的虚拟代码进行了测试。

import numpy as npfrom keras import Inputfrom keras.layers import Masking, Densefrom keras.regularizers import l2from keras.models import Sequential, ModelX_train = np.random.rand(4,3,2)Dense_unit = 1dense_reg = 0.01mdl = Sequential()mdl.add(Input(shape=(X_train.shape[1],X_train.shape[2]),name='input_feature'))mdl.add(Masking(mask_value=0,name='masking'))mdl.add(Dense(Dense_unit,kernel_regularizer=l2(dense_reg),activation='relu',name='output_feature'))mdl.summary()mdl2mask = Model(inputs=mdl.input,outputs=mdl.get_layer("masking").output)maskoutput = mdl2mask.predict(X_train)mdloutput = mdl.predict(X_train)maskoutput # 打印掩码后的输出mdloutput # 打印模型的输出maskoutput.shape #(4, 3, 2): 掩码层的形状与之前的层(这里是输入)相同mdloutput.shape #(4, 3, 1): Dense层的输出形状

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注