我尝试使用tensorflow创建模型。当我执行时,显示了
其他文件在这个链接——- github.com/llSourcell/tensorflow_chatbot
def train(): enc_train, dec_train=data_utils.prepare_custom_data( gConfig['working_directory']) train_set = read_data(enc_train,dec_train)def seq2seq_f(encoder_inputs,decoder_inputs,do_decode): return tf.nn.seq2seq.embedding_attention_seq2seq( encoder_inputs,decoder_inputs, cell, num_encoder_symbols=source_vocab_size, num_decoder_symbols=target_vocab_size, embedding_size=size, output_projection=output_projection, feed_previous=do_decode)with tf.Session(config=config) as sess: model = create_model(sess,False) while True: sess.run(model) checkpoint_path = os.path.join(gConfig['working_directory'],'seq2seq.ckpt') model.saver.save(sess, checkpoint_path, global_step=model.global_step)
除了这个,我使用的其他python文件都在评论部分指定的github链接中
这是execute.py文件中定义create_model的代码
def create_model(session, forward_only): """创建模型并初始化或加载参数""" model = seq2seq_model.Seq2SeqModel( gConfig['enc_vocab_size'], gConfig['dec_vocab_size'], _buckets, gConfig['layer_size'], gConfig['num_layers'], gConfig['max_gradient_norm'], gConfig['batch_size'], gConfig['learning_rate'], gConfig['learning_rate_decay_factor'], forward_only=forward_only) if 'pretrained_model' in gConfig: model.saver.restore(session,gConfig['pretrained_model']) return model ckpt = tf.train.get_checkpoint_state(gConfig['working_directory']) # 最近版本的tensorflow中检查点文件名已更改 checkpoint_suffix = "" if tf.__version__ > "0.12": checkpoint_suffix = ".index" if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path + checkpoint_suffix): print("从 %s 读取模型参数" % ckpt.model_checkpoint_path) model.saver.restore(session, ckpt.model_checkpoint_path) else: print("使用新参数创建模型。") session.run(tf.initialize_all_variables()) return model
回答:
看起来你复制了代码,但没有进行结构化。如果create_model()
在另一个文件中定义,那么你需要导入它。你这样做了吗?(即from file_with_methods import create_model
)。如果你希望我们帮助,你应该考虑编辑你的帖子并添加更多的代码。
替代方案:你也可以克隆你评论中分享的github仓库,然后只需更改execution.py
文件中你想更改的内容。这样你可以保持所有者使用的“层次结构”,并且可以在需要的地方添加自己的代码。