我想在Python中使用sklearn的MiniBatchDictionaryLearning实现字典学习的错误跟踪,以便记录错误在迭代过程中如何减少。我尝试了两种方法,但都没有真正奏效。设置如下:
- 输入数据 X,numpy数组形状为(n_samples, n_features)=(298143, 300)。这些是从形状为(642, 480, 3)的图像中生成的形状为(10, 10)的补丁。
- 字典学习参数:列数(或原子数)= 100,alpha = 2,转换算法 = OMP,总迭代次数 = 500(先保持较小,作为测试用例)
-
计算错误:在学习字典后,我根据学习到的字典再次编码原始图像。由于编码和原始图像都是形状为(642, 480, 3)的numpy数组,我目前只是计算元素级的欧几里得距离:
err = np.sqrt(np.sum(reconstruction – original)**2))
我使用这些参数进行了测试运行,全拟合能够产生一个错误较低的良好重建,这很好。现在让我们看一下这两种方法:
方法1: 每100次迭代保存一次学习到的字典,并记录错误。对于500次迭代,这给我们提供了5个每次100次迭代的运行。在每次运行后,我计算错误,然后使用当前学习到的字典作为下一次运行的初始化。
# 拟合初始字典,V,作为第一次运行
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = 100,
transform_algorithm='omp')
dl = dico.fit(patches)
V = dl.components_
# 现在再进行4次运行。
# 注意温热重启参数,dict_init = V。
for i in range(n_runs):
print("Run %s..." % i, end = "")
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = n_iterations,
transform_algorithm='omp',
dict_init = V)
dl = dico.fit(patches)
V = dl.components_
img_r = reconstruct_image(dico, V, patches)
err = np.sqrt(np.sum((img - img_r)**2))
print("Err = %s" % err)
问题:错误没有减少,而且相当高。字典也没有很好地学习到。
方法2:将输入数据X切分为,例如500个批次,并使用partial_fit()
方法进行部分拟合。
batch_size = 500
n_batches = X.shape[0] // batch_size
print(n_batches) # 596
for iternum in range(n_batches):
batch = patches[iternum*batch_size : (iternum+1)*batch_size]
V = dico.partial_fit(batch)
问题:这似乎要花费大约5000倍的时间。
我想知道是否有方法在拟合过程中检索错误?
回答:
我遇到了同样的问题,最终我能够使代码运行得更快。如果这对某人仍然有用,这里添加解决方案。关键是在构造MiniBatchDictionaryLearning
对象时,我们需要将n_iter
设置为一个较低的值(例如,1),这样每个partial_fit
就不会对单个批次运行太多轮次。
# 构造一个初始字典对象,注意之后会在循环内进行部分拟合,
# 在这里我们只指定对于partial_fit,它只需要在当前提供的批次上运行1轮(n_iter=1),批量大小=batch_size
# (否则默认情况下,它可以对单个partial_fit()运行多达1000次迭代,批量大小为3,并且在每个批次上运行,这会使单次partial_fit()运行非常慢。
# 由于我们自己控制轮次,并且在所有批次完成后重新启动,我们不需要在这里提供超过1次的迭代。
# 这将使代码执行得更快。
batch_size = 128 # 例如,
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = 1, # 每个partial_fit()的轮次
batch_size = batch_size,
transform_algorithm='omp')
接下来是@ogrisel的代码:
n_updates = 0
for epoch in range(n_epochs):
for i in range(n_batches):
batch = patches[i * batch_size:(i + 1) * batch_size]
dico.partial_fit(batch)
n_updates += 1
if n_updates % 100 == 0:
img_r = reconstruct_image(dico, dico.components_, patches)
err = np.sqrt(np.sum((img - img_r)**2))
print("[epoch #%02d] Err = %s" % (epoch, err))