如何奖励一个在类似超级马里奥兄弟的游戏中前进的智能体?我拥有的唯一数据是分数和生命数,但有没有办法获取智能体的坐标?我正在使用NEAT来训练我的智能体,以下是代码。我目前的奖励机制是让它尽可能获得最高分,而奖励它按下右键是行不通的,因为它会一直撞墙并在计时器耗尽前不断获取奖励。
import retroimport numpy as npimport cv2import neatimport pickleenv = retro.make('SuperMarioWorld-Snes', 'Start.state')imgarray = []xpos_end = 0def eval_genomes(genomes, config): for genome_id, genome in genomes: ob = env.reset() ac = env.action_space.sample() inx, iny, inc = env.observation_space.shape inx = int(inx / 8) iny = int(iny / 8) net = neat.nn.recurrent.RecurrentNetwork.create(genome, config) current_max_fitness = 0 fitness_current = 0 frame = 0 counter = 0 xpos = 0 xpos_max = 0 done = False # cv2.namedWindow("main", cv2.WINDOW_NORMAL) while not done: env.render() frame += 1 # scaledimg = cv2.cvtColor(ob, cv2.COLOR_BGR2RGB) # scaledimg = cv2.resize(scaledimg, (iny, inx)) ob = cv2.resize(ob, (inx, iny)) ob = cv2.cvtColor(ob, cv2.COLOR_BGR2GRAY) ob = np.reshape(ob, (inx, iny)) # cv2.imshow('main', scaledimg) # cv2.waitKey(1) imgarray = np.ndarray.flatten(ob) nnOutput = net.activate(imgarray) for i in range(len(nnOutput)): nnOutput[i] = int(nnOutput[i]) if nnOutput[i] < 0: nnOutput[i] = 0 ob, rew, done, info = env.step(nnOutput) # xpos = info['x'] # xpos_end = info['screen_x_end'] # if xpos > xpos_max: # fitness_current += 1 # xpos_max = xpos # if xpos == xpos_end and xpos > 500: # fitness_current += 100000 # done = True fitness_current += rew print(env.statename) if fitness_current > current_max_fitness: current_max_fitness = fitness_current counter = 0 else: counter += 1 if done or counter == 250: done = True print(genome_id, fitness_current) genome.fitness = fitness_currentconfig = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, neat.DefaultSpeciesSet, neat.DefaultStagnation, 'config.txt')p = neat.Population(config)p.add_reporter(neat.StdOutReporter(True))stats = neat.StatisticsReporter()p.add_reporter(stats)p.add_reporter(neat.Checkpointer(10))winner = p.run(eval_genomes)with open('winner.pkl', 'wb') as output: pickle.dump(winner, output, 1)
回答:
使用 print( retro.__file__ )
我找到了包含 retro
模块的文件夹,并检查了所有子文件夹,发现了包含 SuperMarioWorld
的文件夹
在我的Linux系统上,它位于
/usr/local/lib/python3.8/dist-packages/retro/data/stable/SuperMarioWorld-Snes
这里有一个 data.json
文件,它定义了 retro
如何在 ROM
中查找 score
和 lives
在 OpenAI-Retro-SuperMarioWorld-SNES 中,我找到了 data.json,其中还包含了 x
、y
等信息。
如果我替换 data.json
,那么我可以在代码中获取 info["x"]
。
但我不确定这个文件是否适用于所有版本的 SuperMario
。
我测试了 Super Mario World (Europe) (Rev 1)
,这是我在
找到的,但还有其他版本 – 欧洲、美国、日本。