使用Keras实现对决DQN

我正在尝试实现对决DQN,但是如果我按照这种方式构建神经网络架构,它似乎无法学习。

        X_input = Input(shape=(self.state_size,))        X = X_input        X = Dense(512, input_shape= (self.state_size,), activation="relu")(X_input)        X = Dense(260, activation="relu")(X)        X = Dense(100, activation="relu")(X)        state_value = Dense(1)(X)        state_value = Lambda(lambda v: v, output_shape=(self.action_size,))(state_value)        action_advantage = Dense(self.action_size)(X)        action_advantage = Lambda(lambda a: a[:, :] - K.mean(a[:, :], keepdims=True), output_shape=(self.action_size,))(action_advantage)        X = Add()([state_value, action_advantage])        model = Model(inputs = X_input, outputs = X)        model.compile(loss="mean_squared_error", optimizer=Adam(lr=self.learning_rate))        return model

我在网上搜索了一些代码(这些代码的效果比我的好很多),唯一不同的是

        state_value = Lambda(lambda s: K.expand_dims(s[:, 0],-1), output_shape=(self.action_size,))(state_value)

代码链接 https://github.com/pythonlessons/Reinforcement_Learning/blob/master/03_CartPole-reinforcement-learning_Dueling_DDQN/Cartpole_Double_DDQN.py#L31我不明白为什么我的代码无法学习(虽然它能运行)。而且我不明白为什么他只取了张量中每行第一个值?


回答:

扩展状态值的维度确保在执行Add()操作时,它会被加到每个优势值上。

你也可以这样写:去掉lambda函数,并以如下方式写出Q值的实际计算:

X = (state_value + (action_advantage - tf.math.reduce_mean(action_advantage, axis=1, keepdims=True)))

结果将是相同的,但代码可能会更易读。

因此,你的完整代码将如下所示:

X_input = Input(shape=(self.state_size,))X = X_inputX = Dense(512, input_shape= (self.state_size,), activation="relu")(X_input)X = Dense(260, activation="relu")(X)X = Dense(100, activation="relu")(X)state_value = Dense(1)(X)action_advantage = Dense(self.action_size)(X)X = (state_value + (action_advantage - tf.math.reduce_mean(action_advantage, axis=1, keepdims=True)))model = Model(inputs = X_input, outputs = X)model.compile(loss="mean_squared_error", optimizer=Adam(lr=self.learning_rate))return model

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注