This question has probably been asked before, but I couldn't find a relevant answer.
What is the simplest way to keep fetching batches of data from a dataset? Is there a built-in TensorFlow function that does this?
For example:
for i in range(num_trains):
    x_batch, y_batch = get_batch(x_train, y_train, batch_size)
    sess.run(train_step, feed_dict={x: x_batch, y: y_batch})
If there is no such built-in function, how would you implement it? I tried it myself, but I couldn't figure out how to get a new batch, different from the previous one, each time the function is called.
Thanks!
Answer:
You can try the following:
# Yield batches of data
def get_batch(inputX, inputY, batch_size):
    duration = len(inputX)
    for i in range(0, duration // batch_size):
        idx = i * batch_size
        yield inputX[idx:idx + batch_size], inputY[idx:idx + batch_size]
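The generator above yields the batches in a fixed order on every pass. If you also want the batches to differ between epochs, one option is to shuffle a permutation of indices first and slice that. A minimal sketch in plain NumPy (the name `shuffled_batches` is made up here for illustration):

```python
import numpy as np

def shuffled_batches(inputX, inputY, batch_size, seed=None):
    # Shuffle a permutation of the indices, then slice it into batches,
    # so each pass visits the examples in a different random order.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(inputX))
    for i in range(0, len(inputX) // batch_size):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield inputX[idx], inputY[idx]
```

Because fancy indexing with an index array keeps `inputX[idx]` and `inputY[idx]` aligned, the x/y pairing is preserved even though the order is random.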
You can also use TensorFlow's Dataset API:
dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
dataset = dataset.batch(batch_size)
To fetch the batches:
X = np.arange(100)
Y = X
batch = get_batch(X, Y, 5)
batch_x, batch_y = next(batch)
print(batch_x, batch_y)  # [0 1 2 3 4] [0 1 2 3 4]
batch_x, batch_y = next(batch)
print(batch_x, batch_y)  # [5 6 7 8 9] [5 6 7 8 9]
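Note that `get_batch` is a Python generator, so it is exhausted after one full pass; calling `next` on it again raises `StopIteration`. To start over, simply create a fresh generator. A self-contained sketch of both behaviors:

```python
import numpy as np

def get_batch(inputX, inputY, batch_size):
    duration = len(inputX)
    for i in range(0, duration // batch_size):
        idx = i * batch_size
        yield inputX[idx:idx + batch_size], inputY[idx:idx + batch_size]

X = np.arange(10)
Y = X
batch = get_batch(X, Y, 5)
list(batch)          # consumes both batches
try:
    next(batch)      # the generator is now exhausted
except StopIteration:
    print("exhausted")

batch = get_batch(X, Y, 5)   # a fresh generator starts from the beginning
batch_x, batch_y = next(batch)
print(batch_x)  # [0 1 2 3 4]
```

This is why the questioner's "new batch each call" issue usually comes down to keeping one generator alive across calls rather than re-creating it on every call.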
Usually, to run over the dataset for multiple epochs, you would do:
for epoch in range(num_epochs):
    for x_batch, y_batch in get_batch(x_train, y_train, batch_size):
        sess.run(train_step, feed_dict={x: x_batch, y: y_batch})
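One caveat with `get_batch` as written: when `len(inputX)` is not a multiple of `batch_size`, the trailing examples are silently dropped on every epoch. A variant that also yields the final, shorter batch (the name `get_batch_with_remainder` is made up here):

```python
def get_batch_with_remainder(inputX, inputY, batch_size):
    # Step through the data in strides of batch_size;
    # the final slice may be shorter than batch_size.
    for idx in range(0, len(inputX), batch_size):
        yield inputX[idx:idx + batch_size], inputY[idx:idx + batch_size]
```

For comparison, `tf.data`'s `dataset.batch(batch_size)` keeps the final partial batch by default and only drops it when you pass `drop_remainder=True`.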
With the Dataset API:
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
dataset = dataset.batch(5)
iterator = dataset.make_initializable_iterator()
train_x, train_y = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(2):
        print(sess.run([train_x, train_y]))
        # [array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4])]
        # [array([5, 6, 7, 8, 9]), array([5, 6, 7, 8, 9])]