我正在使用Tensorflow进行机器学习工作。
问题:
我无法弄清楚如何将类名转换为类索引。
示例:
预期映射:
Car ---> 0Bike ---> 1Boat ---> 2
代码:
#!/usr/bin/env python3.6import tensorflow as tfnames = [ "Car", "Bus", "Boat"]_, class_name = tf.TextLineReader(skip_header_lines=1).read( tf.train.string_input_producer(tf.gfile.Glob("input_file.csv")))# 我想知道是否可以这样做:# print(sess.run(class_name)) --> "Car"# class_index = f(class_name, names)# print(sess.run(class_index)) --> 0
input_file.csv :
class_nameCarCarBoatBike...
回答:
最简单的方法是这样的:
class_index = tf.reduce_min(tf.where(tf.equal(names, class_name)))
请注意,当类在names
中存在时,它工作得很好,但当类不存在时(如你的例子中的Bike
),它会返回263 − 1。你可以通过移除tf.reduce_min
来避免这种情况,但这样class_index
将评估为数组,而不是标量。
完整的可运行代码:
names = ["Car", "Bus", "Boat"]_, class_name = tf.TextLineReader(skip_header_lines=1).read( tf.train.string_input_producer(tf.gfile.Glob("input_file.csv")))class_index = tf.reduce_min(tf.where(tf.equal(names, class_name)))with tf.Session() as session: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(4): print(class_name.eval()) # Car, Car, Boat, Bike for i in range(4): print(class_index.eval()) # 0, 0, 2, 9223372036854775807 coord.request_stop() coord.join(threads)