如何在TensorFlow之外使用最终预测值复制TensorFlow的损失值?

我正在尝试学习TensorFlow。我已经构建了一个简单模型的演示版。我注意到,如果我使用最终预测值并与实际值进行比较,我无法得到与TensorFlow相同的损失值。

这个演示模型是一个时间序列回归,试图预测时间序列中的下一个值。模型本身没有什么特别之处,我不是在寻求关于模型构建的建议。我只是想理解TensorFlow在做什么,而这段代码计算出的损失值是我无法复制的。

我在R中使用tensorflow包进行操作,这是一个直接绑定到Python库的包。所以代码很容易移植。

我使用的时间序列值是:

prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194,       101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094,       113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375,       122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144,       123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694,       121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481,       119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449,       121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811,       123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250,       121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765,       126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390,       123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979,       122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941,       118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342,       119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776,       111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324,       102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,       105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056,       111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697,       107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201,       107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842,       117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526,       110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136,       113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336,       107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279,       104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245,       94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732,       91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522,       93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004,       100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789,       103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300,       107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242,       104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167,       91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,       92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282,       96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149,       94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088,       94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974,       96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880,       106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567,       107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514,       106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749,       112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021,       111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296,       115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287,       109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942,       106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963,       109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840,       114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740,       115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963,       118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,       120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537,       132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934,       136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682,       138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835,       143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502,       141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919,       143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527,       152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000,       153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000,       154.990000,148.980000,145.320000,146.590000)

我使用的代码是

#split into train and testv.train = prices[1:500]v.test = prices[-(1:500)]sampleSize = 30getJoinedLaggedData = function(dataVector){  data.len = length(dataVector)  result = NULL  for (i in seq(1, data.len - sampleSize + 1, by = 1)) {    result = rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])  }  list(input=result[,-sampleSize], output = result[,sampleSize])}m.train = getJoinedLaggedData(v.train)#build the modelindata = tf$placeholder(tf$float32, shape(471, sampleSize - 1))outData = tf$placeholder(tf$float32, shape(471))w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300),stddev = 1.0 / sqrt(sampleSize)))b1 = tf$Variable(tf$zeros(shape(300)))l1 <- tf$matmul(indata, w1) + b1w2 <- tf$Variable(tf$truncated_normal(shape(300, 50),stddev = 1.0 / sqrt(300)))b2 <- tf$Variable(tf$zeros(shape(50)))l2 <- tf$matmul(l1, w2) + b2w3 <- tf$Variable(tf$truncated_normal(shape(50, 1),stddev = 1.0 / sqrt(50)))b3 <- tf$Variable(tf$zeros(shape(1)))pred <- tf$matmul(l2, w3) + b3#loss functionloss <- tf$reduce_mean(tf$abs(tf$subtract(x = pred, y = outData)))#traineroptimizer <- tf$train$GradientDescentOptimizer(0.00003)train.op <- optimizer$minimize(loss)#run the modelinit = tf$global_variables_initializer()sess = tf$Session()sess$run(init)for (i in 1:1000){  values = sess$run(list(train.op, loss, pred), feed_dict = dict(indata = m.train$input, outData = m.train$output))  loss_value = values[[2]]  pred_value = values[[3]]  print(loss_value)}myLoss = mean(abs(pred_value - m.train$output))print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))

当我运行这段代码时,TensorFlow计算出的损失值约为11.71,而我计算出的损失值约为5.03

我尝试的其他损失函数也出现了这种情况。

任何建议都将不胜感激!


回答:

差异的原因在于m.train$output,TensorFlow将其读取为向量而不是矩阵。

我已经按照要求修改了代码。请检查

prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194,       101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094,       113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375,       122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144,       123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694,       121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481,       119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449,       121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811,       123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250,       121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765,       126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390,       123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979,       122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941,       118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342,       119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776,       111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324,       102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,       105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056,       111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697,       107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201,       107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842,       117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526,       110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136,       113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336,       107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279,       104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245,       94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732,       91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522,       93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004,       100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789,       103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300,       107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242,       104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167,       91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,       92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282,       96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149,       94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088,       94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974,       96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880,       106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567,       107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514,       106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749,       112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021,       111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296,       115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287,       109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942,       106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963,       109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840,       114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740,       115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963,       118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,       120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537,       132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934,       136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682,       138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835,       143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502,       141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919,       143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527,       152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000,       153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000,       154.990000,148.980000,145.320000,146.590000)v.train = prices[1:500]v.test = prices[-(1:500)]sampleSize = 30getJoinedLaggedData = function(dataVector){  data.len = length(dataVector)  result = NULL  for (i in seq(1, data.len - sampleSize + 1, by = 1)) {    result = rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])  }  list(input=matrix(result[,-sampleSize], ncol = sampleSize - 1), output = matrix(result[,sampleSize]))}m.train = getJoinedLaggedData(v.train)#build the modelindata = tf$placeholder(tf$float32, shape(471, sampleSize - 1))outData = tf$placeholder(tf$float32, shape(471, 1))w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300),stddev = 1.0 / sqrt(sampleSize)))b1 = tf$Variable(tf$zeros(shape(300)))l1 <- tf$matmul(indata, w1) + b1w2 <- tf$Variable(tf$truncated_normal(shape(300, 50),stddev = 1.0 / sqrt(300)))b2 <- tf$Variable(tf$zeros(shape(50)))l2 <- tf$matmul(l1, w2) + b2w3 <- tf$Variable(tf$truncated_normal(shape(50, 1),stddev = 1.0 / sqrt(50)))b3 <- tf$Variable(tf$zeros(shape(1)))pred <- tf$matmul(l2, w3) + b3#loss functionloss <- tf$reduce_mean(tf$abs(tf$sub(x = pred, y = outData)))#traineroptimizer <- tf$train$GradientDescentOptimizer(0.00003)train.op <- optimizer$minimize(loss)#run the modelinit = tf$initialize_all_variables()sess = tf$Session()sess$run(init)for (i in 1:1000){  values = sess$run(list(train.op, loss, pred), feed_dict = dict(indata = m.train$input, outData = m.train$output))  loss_value = values[[2]]  pred_value = values[[3]]  print(loss_value)}myLoss = mean(abs(pred_value - m.train$output))print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注