我正在尝试学习TensorFlow。我已经构建了一个简单模型的演示版。我注意到,如果我使用最终预测值并与实际值进行比较,我无法得到与TensorFlow相同的损失值。
这个演示模型是一个时间序列回归,试图预测时间序列中的下一个值。模型本身没有什么特别之处,我不是在寻求关于模型构建的建议。我只是想理解TensorFlow在做什么,而这段代码计算出的损失值是我无法复制的。
我在R中使用tensorflow包进行操作,这是一个直接绑定到Python库的包。所以代码很容易移植。
我使用的时间序列值是:
# Price time series used by both the question's and the answer's scripts.
# The scripts below take elements 1:500 as the training set and the
# remainder as the test set, then regress each value on the preceding 29.
prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194, 101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094, 113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375, 122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144, 123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694, 121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481, 119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449, 121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811, 123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250, 121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765, 126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390, 123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979, 122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941, 118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342, 119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776, 111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324, 102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038, 
105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056, 111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697, 107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201, 107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842, 117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526, 110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136, 113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336, 107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279, 104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245, 94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732, 91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522, 93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004, 100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789, 103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300, 107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242, 104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167, 91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273, 
92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282, 96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149, 94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088, 94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974, 96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880, 106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567, 107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514, 106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749, 112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021, 111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296, 115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287, 109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942, 106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963, 109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840, 114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740, 115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963, 118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434, 
120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537, 132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934, 136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682, 138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835, 143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502, 141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919, 143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527, 152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000, 153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000, 154.990000,148.980000,145.320000,146.590000)
我使用的代码是
# Question's script, reformatted from the collapsed one-line paste (as pasted,
# each inline `#` comment swallowed the rest of the physical line, so the code
# could not run at all).
#
# NOTE(review): this version deliberately keeps the behavior the question asks
# about. `pred` has shape (471, 1) while `outData` is fed a plain length-471
# vector (rank 1). `tf$subtract()` therefore broadcasts the pair to a
# (471, 471) matrix and `tf$reduce_mean()` averages all 471^2 differences --
# which is why the TensorFlow loss (~11.71) does not match the element-wise
# R-side `mean(abs(...))` (~5.03).

# Split into train and test sets.
v.train <- prices[1:500]
v.test <- prices[-(1:500)]
sampleSize <- 30

# Slide a window of `sampleSize` over the series: the first sampleSize - 1
# values of each window are the model input, the last one is the target.
getJoinedLaggedData <- function(dataVector) {
  data.len <- length(dataVector)
  result <- NULL
  for (i in seq(1, data.len - sampleSize + 1, by = 1)) {
    result <- rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])
  }
  # `output` is a bare vector here -- the root cause of the shape mismatch.
  list(input = result[, -sampleSize], output = result[, sampleSize])
}
m.train <- getJoinedLaggedData(v.train)

# Build the model: three stacked linear layers (no activations).
indata <- tf$placeholder(tf$float32, shape(471, sampleSize - 1))
outData <- tf$placeholder(tf$float32, shape(471))  # rank 1: does NOT match pred

w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300), stddev = 1.0 / sqrt(sampleSize)))
b1 <- tf$Variable(tf$zeros(shape(300)))
l1 <- tf$matmul(indata, w1) + b1

w2 <- tf$Variable(tf$truncated_normal(shape(300, 50), stddev = 1.0 / sqrt(300)))
b2 <- tf$Variable(tf$zeros(shape(50)))
l2 <- tf$matmul(l1, w2) + b2

w3 <- tf$Variable(tf$truncated_normal(shape(50, 1), stddev = 1.0 / sqrt(50)))
b3 <- tf$Variable(tf$zeros(shape(1)))
pred <- tf$matmul(l2, w3) + b3  # shape (471, 1)

# Loss: intended as mean absolute error, but see the NOTE above -- the
# (471, 1) - (471,) subtraction broadcasts to (471, 471).
loss <- tf$reduce_mean(tf$abs(tf$subtract(x = pred, y = outData)))

# Trainer.
optimizer <- tf$train$GradientDescentOptimizer(0.00003)
train.op <- optimizer$minimize(loss)

# Run the model.
init <- tf$global_variables_initializer()
sess <- tf$Session()
sess$run(init)
for (i in 1:1000) {
  values <- sess$run(
    list(train.op, loss, pred),
    feed_dict = dict(indata = m.train$input, outData = m.train$output)
  )
  loss_value <- values[[2]]
  pred_value <- values[[3]]
  print(loss_value)
}

# Element-wise MAE recomputed in R; differs from `loss_value` because of the
# broadcasting described above.
myLoss <- mean(abs(pred_value - m.train$output))
print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))
当我运行这段代码时,TensorFlow计算出的损失值约为11.71,而我自己计算出的损失值约为5.03。
我尝试的其他损失函数也出现了这种情况。
任何建议都将不胜感激!
回答:
差异的原因在于m.train$output:TensorFlow将其读取为长度471的向量(秩为1),而不是(471, 1)的矩阵。由于广播机制,形状为(471, 1)的pred减去长度471的向量会得到一个(471, 471)的矩阵,reduce_mean随后对全部471²个差值取平均,因此TensorFlow报告的损失与你在R中逐元素计算的损失不一致。
我已相应地修改了代码,请检查:
# Answer's corrected script, reformatted from the collapsed one-line paste
# (inline `#` comments had swallowed the rest of each physical line).
# Two API fixes versus the pasted answer: `tf$sub()` was removed in TF 1.0 in
# favor of `tf$subtract()`, and `tf$initialize_all_variables()` is deprecated
# in favor of `tf$global_variables_initializer()` -- both modern names are the
# ones the question's own code already used.

# Price time series (identical to the question's data).
prices <- c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194, 101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094, 113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375, 122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144, 123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694, 121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481, 119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449, 121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811, 123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250, 121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765, 126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390, 123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979, 122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941, 118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342, 119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776, 111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324, 102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,
105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056, 111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697, 107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201, 107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842, 117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526, 110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136, 113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336, 107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279, 104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245, 94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732, 91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522, 93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004, 100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789, 103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300, 107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242, 104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167, 91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,
92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282, 96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149, 94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088, 94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974, 96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880, 106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567, 107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514, 106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749, 112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021, 111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296, 115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287, 109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942, 106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963, 109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840, 114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740, 115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963, 118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,
120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537, 132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934, 136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682, 138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835, 143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502, 141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919, 143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527, 152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000, 153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000, 154.990000,148.980000,145.320000,146.590000)

# Split into train and test sets.
v.train <- prices[1:500]
v.test <- prices[-(1:500)]
sampleSize <- 30

# Build a matrix of lagged windows. The key fix versus the question's code:
# both `input` and `output` are returned as explicit matrices, so the target
# has shape (471, 1) and matches `pred` exactly -- no (471, 471) broadcast.
getJoinedLaggedData <- function(dataVector) {
  data.len <- length(dataVector)
  result <- NULL
  for (i in seq(1, data.len - sampleSize + 1, by = 1)) {
    result <- rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])
  }
  list(
    input = matrix(result[, -sampleSize], ncol = sampleSize - 1),
    output = matrix(result[, sampleSize])  # (471, 1) matrix, not a bare vector
  )
}
m.train <- getJoinedLaggedData(v.train)

# Build the model: three stacked linear layers.
indata <- tf$placeholder(tf$float32, shape(471, sampleSize - 1))
outData <- tf$placeholder(tf$float32, shape(471, 1))  # rank 2, matches pred

w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300), stddev = 1.0 / sqrt(sampleSize)))
b1 <- tf$Variable(tf$zeros(shape(300)))
l1 <- tf$matmul(indata, w1) + b1

w2 <- tf$Variable(tf$truncated_normal(shape(300, 50), stddev = 1.0 / sqrt(300)))
b2 <- tf$Variable(tf$zeros(shape(50)))
l2 <- tf$matmul(l1, w2) + b2

w3 <- tf$Variable(tf$truncated_normal(shape(50, 1), stddev = 1.0 / sqrt(50)))
b3 <- tf$Variable(tf$zeros(shape(1)))
pred <- tf$matmul(l2, w3) + b3  # shape (471, 1)

# Loss: mean absolute error. `tf$subtract()` replaces the removed `tf$sub()`.
loss <- tf$reduce_mean(tf$abs(tf$subtract(x = pred, y = outData)))

# Trainer.
optimizer <- tf$train$GradientDescentOptimizer(0.00003)
train.op <- optimizer$minimize(loss)

# Run the model. `tf$global_variables_initializer()` replaces the deprecated
# `tf$initialize_all_variables()`.
init <- tf$global_variables_initializer()
sess <- tf$Session()
sess$run(init)
for (i in 1:1000) {
  values <- sess$run(
    list(train.op, loss, pred),
    feed_dict = dict(indata = m.train$input, outData = m.train$output)
  )
  loss_value <- values[[2]]
  pred_value <- values[[3]]
  print(loss_value)
}

# Recompute MAE in R; it now matches TensorFlow's value because pred_value
# and m.train$output are both (471, 1) matrices.
myLoss <- mean(abs(pred_value - m.train$output))
print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))