我在Coursera的课程测试中发现了关于批处理生成器循环的问题:
def batch_generator(items, batch_size): for i in range(0, len(list(items)), batch_size): yield list[i:i+batch_size]# 测试批处理生成器def _test_items_generator(): for i in range(10): yield i print(i)grader.set_answer("a4FK1", list(map(lambda x: len(x), batch_generator(_test_items_generator(), 3))))
错误看起来像这样:
TypeError Traceback (most recent call last)<ipython-input-85-a91baa3cf6fa> in <module>() 6 7 print(i)----> 8 grader.set_answer("a4FK1", list(map(lambda x: len(x), batch_generator(_test_items_generator(), 3))))<ipython-input-84-4e82a37b7646> in batch_generator(items, batch_size) 12 """ 13 for i in range(0, len(list(items)), batch_size):---> 14 yield list[i:i+batch_size] 15 16 ### YOUR CODE HERETypeError: 'type' object is not subscriptable
我不知道应该在哪里修复我的问题。
回答:
你试图对内置的list
对象进行索引操作。你需要将输入(items
)转换为列表,然后对其进行索引操作:
def batch_generator(items, batch_size): l = list(items) for i in range(0, len(l), batch_size): yield l[i:i+batch_size]
如果你不想转换成列表(例如当items
本身是一个生成器时),你也可以直接循环items
输入:
def batch_generator(items, batch_size): res = [] for item in items: res.append(item) if len(res) == batch_size: yield(res) res = [] if len(res) > 0: yield(res)
请注意,我们需要在循环结束时检查res
中是否有剩余数据,如果有,也需要yield出来。