什么是TensorFlow word2vec示例中权重和偏置的用途?

我在尝试理解word2vec示例的工作原理,但我并不完全理解传递给nse_loss函数的权重和偏置的作用。该函数有两个变量输入:权重(加上偏置)和嵌入。

# Look up embeddings for inputs.embeddings = tf.Variable(    tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))embed = tf.nn.embedding_lookup(embeddings, train_inputs)# Construct the variables for the NCE lossnce_weights = tf.Variable(    tf.truncated_normal([vocabulary_size, embedding_size],                        stddev=1.0 / math.sqrt(embedding_size)))nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

两者都是随机初始化的,并且(据我所知)在学习过程中都将被更新。

# Compute the average NCE loss for the batch.loss = tf.reduce_mean(  tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,                 num_sampled, vocabulary_size))

我认为它们都应该代表训练后的模型。然而,权重和偏置在后续的相似性计算中从未被使用。相反,只使用了一个组件:

# Compute the cosine similarity between minibatch examples and all embeddings.norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))normalized_embeddings = embeddings / normvalid_embeddings = tf.nn.embedding_lookup(  normalized_embeddings, valid_dataset)similarity = tf.matmul(  valid_embeddings, normalized_embeddings, transpose_b=True)

那么模型的第二个组件呢?为什么权重和偏置被忽略了?

谢谢你。


回答:

在word2vec中,你想要的是词语的向量表示。为了实现这一点,你可以使用神经网络等方法。因此,你有输入神经元、输出和隐藏层。你通过设置隐藏层的神经元数量与你希望在向量中的维度相同来学习向量表示。每个词有一个输入和一个输出。然后,你训练网络从输出学习输入,但在中间有一个较小的层,你可以将其视为输入在向量中的编码。这就是权重和偏置所在的地方。但你之后不需要它们,测试时使用的是包含词和代表该词的向量的字典。这样做比运行神经网络获取表示更快。这就是你之后看不到它们的原因。

你写的关于余弦距离的最后一段代码是用来知道哪些向量与你计算的向量接近。你有一些词(向量),你进行一些操作(如:国王 – 男人 + 女人),然后你有一个你想转换成结果的向量。这是所有向量之间的余弦函数运行(女王与操作结果向量的距离最小)。

总之,在验证阶段你看不到权重和偏置,因为你不需要它们。你使用在训练中创建的字典。

更新 s0urcer 更好地解释了向量表示是如何创建的。

网络的输入层和输出层代表词语。这意味着如果词不在那里,值为0,如果词在那里,值为1。第一个位置是一个词,第二个是另一个词,依此类推。你有与词同样多的输入/输出神经元。

中间层是上下文,或者是你对词的向量表示。

现在你用句子或一组连续的词来训练网络。从这组词中,你取一个词并将其设置为输入,其他的词是网络的输出。因此,网络基本上学会了一个词如何与其上下文中的其他词相关联。

要获得每个词的向量表示,你将该词的输入神经元设置为1,然后查看上下文层(中间层)的值。这些值就是向量的值。因为所有输入都是0,除了那个词是1,所以这些值是输入神经元与上下文之间的连接的权重。

你之后不使用网络,因为你不需要计算上下文层的全部值,那样会更慢。你只需要在你的字典中查找该词的值即可。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注