为什么Keras会给我数组中那么多的权重,这些权重是用来做什么的?

以上代码给我返回了六个权重数组。我理解weights[0]是一个2×20的连接集合,但weights[1]只是数组中的20个成员,可能是每个神经元一个。但我以为权重是连接之间的,所以这到底是什么?

结果如下:

Model: "sequential_11"_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================flatten_11 (Flatten)         (None, 2)                 0         _________________________________________________________________dense_33 (Dense)             (None, 20)                60        _________________________________________________________________dense_34 (Dense)             (None, 20)                420       _________________________________________________________________dense_35 (Dense)             (None, 1)                 21        =================================================================Total params: 501Trainable params: 501Non-trainable params: 0_________________________________________________________________None<tf.Variable 'dense_33/kernel:0' shape=(2, 20) dtype=float32, numpy=array([[-0.08592772, -0.4262397 ,  0.32593143,  0.40175033, -0.11370629,         0.29291457, -0.33887625,  0.09051579, -0.11669007,  0.15766495,        -0.03898111,  0.47355425,  0.4038219 , -0.16283795, -0.52166206,         0.08563775,  0.10119641,  0.35014063, -0.29258126,  0.11257637],       [ 0.32310146, -0.00564504, -0.39950165,  0.3422314 , -0.1736508 ,        -0.15470237,  0.03384084, -0.50031585,  0.17582124, -0.20669848,         0.38023835,  0.45190394,  0.22054166, -0.3583283 , -0.31276733,        -0.42144495,  0.05265975,  0.28793246, -0.12343103, -0.52028173]],      dtype=float32)><tf.Variable 'dense_33/bias:0' shape=(20,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,       0., 0., 0.], dtype=float32)><tf.Variable 'dense_34/kernel:0' shape=(20, 20) dtype=float32, numpy=array([[ 0.3640272 , -0.1974021 , -0.18908244,  0.36629468, -0.19073088,         0.3649723 ,  0.20782089,  0.3469664 ,  0.07804716,  0.31226724,        -0.29463375,  0.1686036 , -0.20033775,  0.30975252,  0.02273259,        -0.20266692, -0.27747104,  0.20879221, -0.01603091, -0.18493965],       [ 0.35550284, -0.05491427,  0.17644823, -0.33692163, -0.19445068,        -0.04921806, -0.32028157, -0.25789088, -0.2392104 ,  0.2923655 ,        -0.13487688,  0.07437733, -0.08911327, -0.05451488,  0.26065683,        -0.15244626,  0.16566378,  0.35106057,  0.2891826 , -0.31660333],       [ 0.27196383,  0.06386155, -0.00255299, -0.08238894,  0.10117772,        -0.1466422 , -0.3506349 , -0.17026383,  0.37896347,  0.38407606,         0.11256388,  0.04563713, -0.37850204, -0.02221993,  0.10061872,         0.3318779 ,  0.02184781, -0.2571288 , -0.3272443 , -0.30387318],       [-0.30832642, -0.26533905, -0.17655498, -0.17255014, -0.25053996,        -0.23460795,  0.35648036,  0.25864762,  0.2033844 , -0.1375912 ,        -0.21034713,  0.23340106,  0.07856712,  0.06940717,  0.3642025 ,        -0.10192937,  0.3174355 , -0.16317467,  0.12438256,  0.34762514],       [ 0.29233575, -0.23410663,  0.17361456, -0.35682717,  0.29746616,        -0.13978545,  0.19865829, -0.32398444,  0.27369195,  0.29182   ,         0.28845608,  0.3714251 , -0.00226831, -0.11382625, -0.03799275,        -0.38083866,  0.13849735, -0.17412637,  0.30680603, -0.32791764],       [-0.36856598, -0.01098448,  0.2209788 , -0.01641467,  0.36460286,         0.3742503 , -0.07144001, -0.32689905, -0.2800351 , -0.26420033,        -0.3203124 ,  0.22266299,  0.05407029,  0.20716977,  0.23186374,         0.34451336, -0.3665755 , -0.08111835,  0.02044231,  0.22657269],       [-0.18087737, -0.09122089, -0.3162348 , -0.01350608,  0.2994557 ,         0.00759923,  0.07653233, -0.11245179, -0.06106046,  0.09489083,        -0.34051555, -0.0210776 , -0.3720226 , -0.08034962, -0.3628871 ,        -0.08755568,  0.13865143, -0.13755408,  0.18153298,  0.23439962],       [-0.36453527, -0.3077588 ,  0.06971669,  0.23991793, -0.32902858,        -0.1256682 , -0.37355578, -0.22176625, -0.06080669,  0.12455881,         0.02237046,  0.21177506,  0.05803809, -0.07626435, -0.36375207,         0.13273174,  0.15075874, -0.18664922,  0.20256019,  0.17832053],       [-0.09238327,  0.03065437,  0.04975492,  0.03068706, -0.01132107,         0.04134732,  0.2726786 ,  0.09169459,  0.16609359, -0.26199952,         0.34235936, -0.3293307 , -0.2625829 ,  0.05643666, -0.19363837,        -0.09321746,  0.15029383,  0.1271655 , -0.13643244, -0.1260187 ],       [-0.10736805,  0.08597881, -0.28592098,  0.32719833,  0.25863254,        -0.35738683,  0.28420174,  0.07898697,  0.12083912,  0.24187142,         0.20364356,  0.16368687,  0.3372751 , -0.11902198,  0.29610634,        -0.26228833,  0.26691556,  0.02676412,  0.20875496,  0.2742722 ],       [-0.01111042,  0.01864234, -0.3684872 ,  0.25593793,  0.05572906,        -0.27395982, -0.18536313, -0.28665859,  0.33866453, -0.04277194,         0.31874043,  0.17231691,  0.26513118,  0.2841534 ,  0.38413507,         0.32093495, -0.1821885 ,  0.3448484 ,  0.06886706,  0.05471361],       [ 0.27029324,  0.17785454, -0.3417698 , -0.18585834,  0.13658857,         0.25487036, -0.34464136, -0.31934893, -0.07558686,  0.12984264,        -0.12386304,  0.33101034,  0.2395941 , -0.35495222, -0.38362566,         0.02023152, -0.38369113,  0.10231277, -0.00923318, -0.2564116 ],       [ 0.3738134 , -0.13709581, -0.02727005, -0.38571945, -0.17950383,        -0.08438393, -0.35748094, -0.02563897, -0.26492482, -0.3148442 ,         0.27745587,  0.01215285,  0.0338603 ,  0.22927964, -0.26310933,         0.17490405,  0.15125847,  0.33357888, -0.10504535,  0.09216848],       [-0.18665901,  0.14127249, -0.31025392,  0.3109604 ,  0.11353964,         0.1544854 , -0.0628956 ,  0.2526992 , -0.38535342, -0.35054773,         0.3639174 , -0.22744954,  0.2787813 ,  0.25469422, -0.24284746,         0.2586198 ,  0.38151866,  0.14534372, -0.07336038,  0.35205972],       [-0.0222263 , -0.09051144,  0.19810867,  0.2596696 ,  0.16493648,        -0.15432249, -0.12816939,  0.26651537,  0.33925128,  0.24480599,        -0.20886998, -0.23604779,  0.35640693, -0.1257923 , -0.3385602 ,         0.37019014, -0.34767368, -0.20191407,  0.05838048, -0.3322008 ],       [-0.22849075,  0.31127506, -0.1032331 ,  0.03278631, -0.3802262 ,         0.06519806, -0.10763076, -0.23816115,  0.29874003, -0.17749721,        -0.10582674, -0.03064901,  0.18550068,  0.08624834,  0.09579298,        -0.305739  ,  0.00272122,  0.14033073,  0.22830683, -0.17147864],       [-0.23749031, -0.36036015, -0.15639098,  0.16943222,  0.33908015,        -0.18797807, -0.31251115,  0.13584453, -0.10717931, -0.11736256,         0.17281443,  0.1897279 , -0.35898107,  0.21381551,  0.3051238 ,         0.12489098,  0.29044586, -0.20301346, -0.25790715,  0.04153055],       [-0.25941753, -0.3507824 ,  0.34750968, -0.04910356, -0.1914334 ,        -0.22343925,  0.3420688 ,  0.38251758, -0.09309632,  0.3546936 ,        -0.22427556, -0.24499758,  0.00074324, -0.06633586,  0.1922136 ,         0.11927372, -0.19837731, -0.23528719, -0.26004478, -0.24688683],       [-0.03148285, -0.32766464, -0.2530514 ,  0.1765365 ,  0.26583946,        -0.18146862, -0.20307828, -0.07899943, -0.10167924, -0.05031338,         0.03324467, -0.27283487, -0.3672278 , -0.24607424, -0.15097658,         0.18689764, -0.32162574,  0.10529301, -0.2671068 , -0.29016626],       [ 0.23749584, -0.10357189,  0.03281826,  0.30171496,  0.3568563 ,        -0.27596533,  0.3714081 ,  0.04653817,  0.14261234, -0.18491131,         0.3152057 ,  0.23654068,  0.07070702,  0.0457052 , -0.17505457,         0.15374076, -0.03659964,  0.2212556 , -0.05467528,  0.160887  ]],      dtype=float32)><tf.Variable 'dense_34/bias:0' shape=(20,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,       0., 0., 0.], dtype=float32)><tf.Variable 'dense_35/kernel:0' shape=(20, 1) dtype=float32, numpy=array([[ 0.44639164],       [ 0.46424013],       [-0.36026257],       [-0.19160783],       [-0.236644  ],       [ 0.36841106],       [ 0.5083434 ],       [-0.00797582],       [ 0.25151885],       [-0.51940155],       [-0.03726539],       [-0.15949944],       [-0.2284751 ],       [ 0.4611426 ],       [ 0.2685169 ],       [ 0.1900658 ],       [-0.4574982 ],       [ 0.22935611],       [-0.46032292],       [-0.28261   ]], dtype=float32)><tf.Variable 'dense_35/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>

回答:

model.weights 返回所有权重。您的理解是正确的。 weights[0] 是一个2×20的连接集合,weights[1] 是对应偏置的权重。因为您使用的是 tf.keras.layers.Dense,默认情况下 use_bias=True。所以,weight[1] 指的是第一层Dense层的偏置权重。其他Dense层的情况也一样。尝试将其中任何一层设置为 use_bias=False,您会注意到差异。例如 –

model = keras.Sequential([ keras.layers.Flatten(input_shape=(2,)), keras.layers.Dense(20, activation=tf.nn.relu, use_bias=False), keras.layers.Dense(20, activation=tf.nn.relu), keras.layers.Dense(1) ]) 
len(model.weights) 

这会返回一个包含5个元素的列表。它忽略了第一层的偏置权重。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注