我在尝试使用Tensorflow中的tf.contrib.training.stratified_sample
来平衡类别。我做了一个简单的示例来测试它,从两个不平衡的类别中以平衡的方式抽取样本并进行验证,但却遇到了错误。
import tensorflow as tffrom tensorflow.python.framework import opsfrom tensorflow.python.framework import dtypesbatch_size = 10data = ['a']*9990+['b']*10labels = [1]*9990+[0]*10data_tensor = ops.convert_to_tensor(data, dtype=dtypes.string)label_tensor = ops.convert_to_tensor(labels)target_probs = [0.5,0.5]data_batch, label_batch = tf.contrib.training.stratified_sample( data_tensor, label_tensor, target_probs, batch_size, queue_capacity=2*batch_size)with tf.Session() as sess: d,l = sess.run(data_batch,label_batch)print('percentage "a" = %.3f' % (np.sum(l)/len(l)))
我得到的错误是:
Traceback (most recent call last): File "/home/jason/code/scrap.py", line 56, in <module> test_stratified_sample() File "/home/jason/code/scrap.py", line 47, in test_stratified_sample queue_capacity=2*batch_size) File "/usr/local/lib/python3.4/dist-packages/tensorflow/contrib/training/python/training/sampling_ops.py", line 191, in stratified_sample with ops.name_scope(name, 'stratified_sample', tensors + [labels]): File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/ops/math_ops.py", line 829, in binary_op_wrapper y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y") File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 676, in convert_to_tensor as_ref=False) File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 741, in internal_convert_to_tensor ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/constant_op.py", line 113, in _constant_tensor_conversion_function return constant(v, dtype=dtype, name=name) File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/constant_op.py", line 102, in constant tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape)) File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_util.py", line 374, in make_tensor_proto _AssertCompatible(values, dtype) File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_util.py", line 302, in _AssertCompatible (dtype.name, repr(mismatch), type(mismatch).__name__)) TypeError: Expected string, got list containing Tensors of type '_Message' instead.
错误信息并没有解释我做错了什么。我还尝试过直接输入原始数据和标签(不转换为张量),以及使用tf.train.slice_input_producer
来创建数据和标签张量的初始队列。
有没有人成功使用过stratified_sample
?我没能找到任何示例。
回答:
我已经修改了代码,使其对我来说可以工作。修改的总结如下:
- 使用
enqueue_many=True
来一次性加入一批具有不同标签的示例。否则,它期望一个单一的标量标签张量(这在由队列运行器评估时可能是随机的)。 - 第一个参数期望是一个张量列表。它应该有更好的错误信息(我想这就是你遇到的问题)。请提交一个拉取请求或在Github上打开一个问题来改进错误信息。
- 启动队列运行器。否则使用队列的代码会死锁。或者使用
Estimator
或MonitoredSession
,这样你就不需要担心这个问题了。 - (基于评论的编辑)
stratified_sample
不打乱数据,它只是接受/拒绝!所以如果你的数据没有随机化,考虑在抽样前通过slice_input_producer
(enqueue_many=False
)或shuffle_batch
(enqueue_many=True
)来随机化数据,如果你希望它以随机顺序输出的话。
修改后的代码(基于@人名的评论改进):
import numpyimport tensorflow as tffrom tensorflow.python.framework import opsfrom tensorflow.python.framework import dtypeswith tf.Graph().as_default(): batch_size = 100 data = ['a']*9000+['b']*1000 labels = [1]*9000+[0]*1000 data_tensor = ops.convert_to_tensor(data, dtype=dtypes.string) label_tensor = ops.convert_to_tensor(labels, dtype=dtypes.int32) shuffled_data, shuffled_labels = tf.train.slice_input_producer( [data_tensor, label_tensor], shuffle=True, capacity=3*batch_size) target_probs = numpy.array([0.5,0.5]) data_batch, label_batch = tf.contrib.training.stratified_sample( [shuffled_data], shuffled_labels, target_probs, batch_size, queue_capacity=2*batch_size) with tf.Session() as session: tf.local_variables_initializer().run() tf.global_variables_initializer().run() coordinator = tf.train.Coordinator() tf.train.start_queue_runners(session, coord=coordinator) num_iter = 10 sum_ones = 0. for _ in range(num_iter): d, l = session.run([data_batch, label_batch]) count_ones = l.sum() sum_ones += float(count_ones) print('percentage "a" = %.3f' % (float(count_ones) / len(l))) print('Overall: {}'.format(sum_ones / (num_iter * batch_size))) coordinator.request_stop() coordinator.join()
输出结果:
percentage "a" = 0.480percentage "a" = 0.440percentage "a" = 0.580percentage "a" = 0.570percentage "a" = 0.580percentage "a" = 0.520percentage "a" = 0.480percentage "a" = 0.460percentage "a" = 0.390percentage "a" = 0.530Overall: 0.503