model.get_weights() 与 model.trainable_variables 在 Tensorflow 中对比

model.get_weights()model.trainable_variables 在 Tensorflow 中似乎返回相同的值,但数据类型不同。前者返回数组列表,后者返回张量数组。(如果我没有理解错的话)

请解释在哪些情况下使用哪一个更合适?另外,我试图比较它们但没有成功,如果可能的话,请提供一个比较代码的示例。

谢谢你。


回答:

model.get_weights():

  • 它返回层当前的权重,以 NumPy 数组的形式。
  • 此函数返回一个包含与该层相关联的可训练和不可训练权重值的 NumPy 数组列表,这些数组可用于将状态加载到参数相似的层中。

model.trainable_variables:

  • 它返回在层中创建的所有变量,当 trainable = True 时。默认情况下,所有层都设置为 true。

因此,model.get_weights()model.trainable_variables 返回相同的值。请参考这个 gist 以更好地理解。谢谢你!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注