在尝试使用DQN算法玩FrozenLake OpenAI游戏时遇到的错误

我试图让一个非常简单的DQN算法与FrozenLake-v0游戏一起工作,但出现了错误。我明白使用DQN而不是Q-table可能有些过头,但我还是希望它能工作。这里是代码:

import gymimport numpy as npimport tensorflow as tfenv = gym.make("FrozenLake-v0")n_actions = env.action_space.ninput_dim = env.observation_space.nmodel = tf.keras.Sequential() model.add(tf.keras.layers.Dense(64, input_dim = input_dim , activation = 'relu'))model.add(tf.keras.layers.Dense(32, activation = 'relu'))model.add(tf.keras.layers.Dense(n_actions, activation = 'linear'))model.compile(optimizer=tf.keras.optimizers.Adam(), loss = 'mse')def replay(replay_memory, minibatch_size=32):    minibatch = np.random.choice(replay_memory, minibatch_size, replace=True)    s_l =      np.array(list(map(lambda x: x['s'], minibatch)))    a_l =      np.array(list(map(lambda x: x['a'], minibatch)))    r_l =      np.array(list(map(lambda x: x['r'], minibatch)))    sprime_l = np.array(list(map(lambda x: x['sprime'], minibatch)))    done_l   = np.array(list(map(lambda x: x['done'], minibatch)))    qvals_sprime_l = model.predict(sprime_l)    target_f = model.predict(s_l)     for i,(s,a,r,qvals_sprime, done) in enumerate(zip(s_l,a_l,r_l,qvals_sprime_l, done_l)):         if not done:  target = r + gamma * np.max(qvals_sprime)        else:         target = r        target_f[i][a] = target    model.fit(s_l,target_f, epochs=1, verbose=0)    return modeln_episodes = 500gamma = 0.99epsilon = 0.9minibatch_size = 32r_sums = []  replay_memory = []mem_max_size = 100000for n in range(n_episodes):     s = env.reset()    done=False    r_sum = 0    print(s)    while not done:         qvals_s = model.predict(s.reshape(16))        if np.random.random() < epsilon:  a = env.action_space.sample()        else:                             a = np.argmax(qvals_s);         sprime, r, done, info = env.step(a)        r_sum += r         if len(replay_memory) > mem_max_size:            replay_memory.pop(0)        replay_memory.append({"s":s,"a":a,"r":r,"sprime":sprime,"done":done})        s=sprime        model=replay(replay_memory, minibatch_size = minibatch_size)    if epsilon > 0.1:      epsilon -= 0.001    r_sums.append(r_sum)    if n % 100 == 0: print(n)

我遇到的错误是:

Traceback (most recent call last):  File "froz_versuch.py", line 48, in <module>    qvals_s = model.predict(s.reshape(16))ValueError: cannot reshape array of size 1 into shape (16,)

当我尝试将qvals_s = model.predict(s.reshape(16))更改为qvals_s = model.predict(s.reshape(1))时,我得到了以下错误:

ValueError: Input 0 of layer sequential is incompatible with the layer: expected axis -1 of input shape to have value 16 but received input with shape [None, 1]

任何帮助我都会非常感激!


回答:

问题与独热编码有关。我必须对ssprime进行编码,使它们具有16的维度。这个在循环中的更改使其工作了。encode()函数可以移到循环之外,但我现在只是在测试,所以优化随后再进行。这里是解决方案:

 for n in range(n_episodes):     ss = env.reset()    states_total = 16    data = [[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]]    def encode(data, states_total):        targets = np.array(data).reshape(-1)        return np.eye(states_total)[targets]    m = encode(data,states_total)    s = m[ss]    #print(s)    #print(len(s))    done=False    r_sum = 0    while not done:         #env.render()        qvals_s = model.predict(s.reshape(1,-1))        if np.random.random() < epsilon:  a = env.action_space.sample()        else:                             a = np.argmax(qvals_s);         sprime, r, done, info = env.step(a)        r_sum += r        q = encode(data,states_total)        sprime = q[sprime]        if len(replay_memory) > mem_max_size:            replay_memory.pop(0)        replay_memory.append({"s":s,"a":a,"r":r,"sprime":sprime,"done":done})        #s = n[sprime]        s=sprime        model=replay(replay_memory, minibatch_size = minibatch_size)    if epsilon > 0.001:      epsilon -= 0.001    r_sums.append(r_sum)    print(r_sum)    print(epsilon)    if n % 100 == 0: print(n)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注