我注意到torch7中存在一种异常行为。我对torch7了解不多,所以不知道这种行为该如何解释或纠正。
我使用的是CIFAR-10数据集。我简单地从CIFAR-10中获取了一张图像的数据,然后将其保存到我的目录中。当我加载这个保存的图像时,发现它与原图不同。
这是我的代码 –
require 'image'i1 = testData.data[2] --从CIFAR-10获取数据image.save("1.png", i) --将数据保存为图像i2 = image.load("1.png") --加载保存的图像if(i1 == i2) then --检查图像1(i1)和图像2(i2)是否不同print("same") end
这种行为是预期的吗?我以为png
应该是无损的。
如果是的话,该如何纠正这种情况?
加载CIFAR-10数据集的代码 –
-- 加载数据集 trainData = { data = torch.Tensor(50000, 3072), labels = torch.Tensor(50000), size = function() return trsize end } for i = 0,4 do local subset = torch.load('cifar-10-batches-t7/data_batch_' .. (i+1) .. '.t7', 'ascii') trainData.data[{ {i*10000+1, (i+1)*10000} }] = subset.data:t() trainData.labels[{ {i*10000+1, (i+1)*10000} }] = subset.labels end trainData.labels = trainData.labels + 1 local subset = torch.load('cifar-10-batches-t7/test_batch.t7', 'ascii') testData = { data = subset.data:t():double(), labels = subset.labels[1]:double(), size = function() return tesize end } testData.labels = testData.labels + 1 testData.data = testData.data:reshape(10000,3,32,32)
回答:
==
运算符比较的是两个张量的指针,而不是内容:
a = torch.Tensor(3, 5):fill(1)b = torch.Tensor(3, 5):fill(1)print(a == b)> falseprint(a:eq(b):all())> true