加载TensorFlow模型后运行前向传播函数

我无法在加载保存的TensorFlow模型后运行前向传播函数。虽然我能够成功提取权重,但是当我尝试向前向传播函数传递新输入时,它会抛出一个“尝试使用未初始化的值”的错误。

我的占位符如下:

x = tf.placeholder('int64', [None, 4], name='input')  # 样本数 x 特征数
y = tf.placeholder('int64', [None, 1], name='output')  # 样本数 x 输出

前向传播函数:

def forwardProp(x, y):
    embedding_mat = tf.get_variable("EM", shape=[total_vocab, e_features], initializer=tf.random_normal_initializer(seed=1))
    # m x words x total_vocab * total_vocab x e_features = m x words x e_features
    # embed_x = tf.tensordot(x, tf.transpose(embedding_mat), axes=[[2], [0]])
    # embed_y = tf.tensordot(y, tf.transpose(embedding_mat), axes=[[2], [0]])
    embed_x = tf.gather(embedding_mat, x)  # m x words x e_features
    embed_y = tf.gather(embedding_mat, y)  # m x words x e_features
    #print("Shape of embed x", embed_x.get_shape())
    W1 = tf.get_variable("W1", shape=[n1, e_features], initializer=tf.random_normal_initializer(seed=1))
    B1 = tf.get_variable("b1", shape=[1, 4, n1], initializer=tf.zeros_initializer())
    # m x words x e_features *  e_features x n1 = m x words x n1
    Z1 = tf.add(tf.tensordot(embed_x, tf.transpose(W1), axes=[[2], [0]]), B1, )
    A1 = tf.nn.tanh(Z1)
    W2 = tf.get_variable("W2", shape=[n2, n1], initializer=tf.random_normal_initializer(seed=1))
    B2 = tf.get_variable("B2", shape=[1, 4, n2], initializer=tf.zeros_initializer())
    # m x words x n1 *  n1 x n2 = m x words x n2
    Z2 = tf.add(tf.tensordot(A1, tf.transpose(W2), axes=[[2], [0]]), B2)
    A2 = tf.nn.tanh(Z2)
    W3 = tf.get_variable("W3", shape=[n3, n2], initializer=tf.random_normal_initializer(seed=1))
    B3 = tf.get_variable("B3", shape=[1, 4, n3], initializer=tf.zeros_initializer())
    # m x words x n2  * n2 x n3 = m x words x n3
    Z3 = tf.add(tf.tensordot(A2, tf.transpose(W3), axes=[[2], [0]]), B3)
    A3 = tf.nn.tanh(Z3)
    # 将 m x words x n3 转换为 m x n3
    x_final = tf.reduce_mean(A3, axis=1)
    y_final = tf.reduce_mean(embed_y, axis=1)
    return x_final, y_final

反向传播函数:

def backProp(X_index, Y_index):
    x_final, y_final = forwardProp(x, y)
    cost = tf.nn.l2_loss(x_final - y_final)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    total_batches = math.floor(m/batch_size)
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(epochs):
            batch_start = 0
            for i in range(int(m/batch_size)):
                x_hot = X_index[batch_start: batch_start + batch_size]
                y_hot = Y_index[batch_start: batch_start + batch_size]
                batch_start += batch_size
                _, temp_cost = sess.run([optimizer, cost], feed_dict={x: x_hot, y: y_hot})
                print("Cost at minibatch:  ", i , " and epoch ", epoch, " is ", temp_cost)
            if m % batch_size != 0:
                x_hot = X_index[batch_start: batch_start+m - (batch_size*total_batches)]
                y_hot = Y_index[batch_start: batch_start+m - (batch_size*total_batches)]
                _, temp_cost = sess.run([optimizer, cost], feed_dict={x: x_hot, y: y_hot})
                print("Cost at minibatch: (beyond floor)  and epoch ", epoch, " is ", temp_cost)
        # 保存模型
        save_path = saver.save(sess, "./model_neural_embeddingV1.ckpt")
        print("模型已保存!")

通过调用预测函数重新加载模型:

def predict_search():
    # 初始化变量
    total_features = 4
    extra = len(word_to_indice)
    query = input('请输入您的查询')
    words = word_tokenize(query)
    # 目前,如果字典中没有某个词,会抛出错误
    features = [word_to_indice[w.lower()] for w in words]
    len_features = len(features)
    X_query = []
    Y_query = [[0]]  # 虚拟变量,我们在进行预测时不关心Y查询
    if len_features < total_features:
        features += [extra] * (total_features - len_features)
    elif len_features > total_features:
        features = features[:total_features]
    X_query.append(features)
    X_query = np.array(X_query)
    print(X_query)
    Y_query = np.array(Y_query)
    # 加载模型
    init_global = tf.global_variables_initializer()
    init_local = tf.local_variables_initializer()
    #X_final, Y_final = forwardProp(x, y)
    with tf.Session() as sess:
        sess.run(init_global)
        sess.run(init_local)
        saver = tf.train.import_meta_graph('./model_neural_embeddingV1.ckpt.meta')
        saver.restore(sess, './model_neural_embeddingV1.ckpt')
        print("模型已加载")
        print("已加载的变量是:")
        print(tf.trainable_variables())
        print(sess.graph.get_operations())
        embedMat = sess.run('EM:0')  # 获取词嵌入矩阵
        W1 = sess.run('W1:0')
        b1 = sess.run('b1:0')
        W2 = sess.run('W2:0')
        b2 = sess.run('B2:0')
        print(b2)
        W3 = sess.run('W3:0')
        b3 = sess.run('B3:0')
        **# 这一部分不起作用,调用前向传播会抛出“尝试使用未初始化的值”的错误。**
        X_final = sess.run(forwardProp(x, y), feed_dict={x: X_query, y: Y_query})
        print(X_final)

回答:

您在从元图中加载它们之后,不小心使用forwardProp函数创建了一大堆图变量,这实际上是无意中复制了您的变量。

您应该重构您的代码,遵循在创建会话之前创建图变量的最佳实践。

例如,在一个名为build_graph的函数中创建所有变量。您可以在创建会话之前调用build_graph,但绝不能在之后调用。这将避免此类混淆。

您几乎总是应该避免在sess.run中调用函数,就像您使用以下方式做的那样:

X_final = sess.run(forwardProp(x, y), feed_dict={x: X_query, y: Y_query})

这样做容易引发错误。

请注意,在forwardProp(x, y)中,您正在创建TensorFlow结构,所有权重和偏置项。

但请注意,您已经在以下两行代码中创建了这些变量:

saver = tf.train.import_meta_graph('./model_neural_embeddingV1.ckpt.meta')
saver.restore(sess, './model_neural_embeddingV1.ckpt')

另一个选项,可能是您试图做的事情,是不使用import_meta_graph。您可以创建所有TensorFlow操作和变量,然后运行saver.restore来恢复检查点,这将把检查点数据映射到您已经创建的变量中。

请注意,在TensorFlow中您实际上有两种选择,这有点令人困惑。您最终同时做了两件事(导入包含所有操作和变量的图),以及重新创建图。您必须选择一种方法。

我通常选择第一种方法,不使用import_meta_graph,只是通过调用您的build_graph函数以编程方式重新创建图。然后调用saver.restore来引入检查点。当然,您将在训练和推理时重用您的build_graph函数,以便每次都得到相同的图。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注