使用Keras构建对决Q网络

我对Keras还比较陌生,目前正在尝试构建一个对决Q网络来训练一个KI。我找到了一段代码片段来构建一个模型,看起来竟然能工作。但我完全不明白其中的原理,因为我不太熟悉Keras中的lambda表达式。能有人解释一下在下面的模型中lambda层的创建是如何工作的吗?非常感谢您的提前帮助!

def build_model():    model = Sequential()    model.add(Dense(units=16, activation='relu', input_dim = 2))    model.add(Dense(units=32, activation='relu'))    model.add(Dense(units=9, activation='relu'))    #我完全不明白下面的层是如何工作的:    model.add(Lambda(lambda i: K.expand_dims(i[:,0],-1) + i[:,1:] - K.mean(i[:,1:], keepdims=True), output_shape=(8,)))    model.add(Dense(units=8, activation='linear'))    model.compile(loss='mse',    optimizer = RMSprop(lr=0.001) )    return model

回答:

我对你的具体研究领域不熟悉,但我可以告诉你这个层在做什么。Lambda层是当你想要对输入执行自定义操作,而这些操作不是Keras预定义的时使用的。具体来说,你想对进入该层的张量应用一些Keras尚未处理的自定义操作。

进入Lambda层的输入是一个匿名函数,其中输入是进入该层的张量。不过请注意,你也可以在这个层中指定任何函数或操作,而不一定是匿名函数……只要它能对输入张量进行操作并产生一个输出即可。然后你定义你想要对这个输入张量执行的操作,并创建一个对应的输出张量,供下一层使用。当然,这种行为假设这是一个前馈网络,这也是我在这里看到的。你可以将匿名函数视为一次性函数,用于执行操作,但一旦你指定了对输入张量做什么后,你就不再需要它们了。

lambda i因此表示你正在创建一个匿名函数,Lambda层将对定义为i的输入张量进行操作。K.expand_dims确保为广播目的添加单例维度。在这种情况下,我们希望取输入张量的第一列i[:,0],它成为一个一维数组,并确保输入张量是一个具有单列的二维数组(即从N,数组变为N x 1数组)。-1参数是你想要扩展的轴。将其设置为-1只是扩展最后一个维度,在这种情况下,最后一个维度是第一个(也是唯一一个)维度。

如果你不习惯广播,添加到这个扩展数组的操作有点难以理解,但一旦你掌握了,它是计算中最强大的机制之一。i[:,1:]切片输入张量,使我们考虑从第二列到最后的张量。在幕后,将这个切片张量与扩展的单列i[:,0]相加意味着这一列会被复制,并分别添加到i[:,1:]中的每一列。

例如,如果i[:,0][1, 2, 3, 4],而i[:,1:][[4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6],执行K.expand_dims(i[:,0], -1) + i[:,1:]的结果是[[5, 6, 7, 8], [6, 7, 8, 9], [7, 8, 9, 10]]

最后一块拼图是这个:K.mean(i[:,1:], keepdims=True)。我们取K.expand_dims(i[:,0], -1) + i[:,1:]然后用K.mean(i[:,1:], keepdims=True)减去它。在这种情况下,K.mean将找到从第二列开始的所有行中所有值的平均值。这是操作的默认行为。根据你如何使用K.mean,一个或多个维度可能会丢失。K.mean的另一个输入是axis,它允许你指定你想要分析张量中哪个维度的平均值。例如,如果你使用axis=0,这将找到每列的平均值。这将减少到一个一维张量值。使用keepdims关键字,如果你指定keepdims=True,这将确保张量仍然是二维的,列数为1(即一个N x 1张量而不是N,张量)。默认行为是false

因此,通过执行K.mean操作,我们确保最终结果是1 x 1,并从K.expand_dims(i[:,0],-1) + i[:,1:]结果的每个值中减去这个值。这再次是由于广播成为可能的。

最后,我们确保这个操作的输出形状给出一个大小为8的一维张量。

tl;dr

这个操作是一个自定义操作,我们取输入张量的第一列,将其加到从第二列开始的所有其他列上,然后用从第二列开始的所有其他列的平均值减去这个结果的每个值。此外,我们约束张量的输出大小,使其为一维,大小为8。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注