(Complete newbie to Python, machine learning and TensorFlow)
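I'm trying to adapt the TensorFlow linear model tutorial from the official TensorFlow documentation to the Abalone dataset from the UCI Machine Learning Repository. The goal is to predict an abalone's number of rings (its age) from the other attributes.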
When I run the program below, I get the following error:
```
File "/home/lawrence/tensorflow3.5/lib/python3.5/site-packages/tensorflow/python/ops/lookup_ops.py", line 220, in lookup
    (self._key_dtype, keys.dtype))
TypeError: Signature mismatch. Keys must be dtype <dtype: 'string'>, got <dtype: 'int32'>.
```
The error is raised at line 220 of lookup_ops.py, and the documentation says it is raised in this case:
Raises: TypeError: when `keys` or `default_value` doesn't match the table data types.
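To make that documented rule concrete, here is a pure-Python sketch of a lookup table that has a fixed key type and rejects keys of any other type. The class `TypedLookupTable` is hypothetical: it models only this type check, not TensorFlow's actual `lookup_ops` implementation.

```python
class TypedLookupTable:
    """Hypothetical, simplified stand-in for a TensorFlow lookup table.

    It models only the documented rule: the table has a fixed key type, and
    looking up keys of another type raises TypeError instead of returning ids.
    """

    def __init__(self, keys, default_value=-1):
        self.key_type = type(keys[0])  # e.g. str for a vocabulary table
        self.ids = {key: i for i, key in enumerate(keys)}
        self.default_value = default_value

    def lookup(self, keys):
        for key in keys:
            if type(key) is not self.key_type:
                raise TypeError('Signature mismatch. Keys must be %s, got %s.'
                                % (self.key_type.__name__, type(key).__name__))
        # Unknown keys of the right type map to default_value.
        return [self.ids.get(key, self.default_value) for key in keys]


table = TypedLookupTable(['M', 'F', 'I'])
print(table.lookup(['F', 'X']))  # → [1, -1]
try:
    table.lookup([15])  # an int key against a str-keyed table
except TypeError as err:
    print(err)  # → Signature mismatch. Keys must be str, got int.
```

A string-keyed table fed integer keys fails before any lookup happens, which is the same shape as the `string` vs `int32` mismatch above.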
By debugging parse_csv(), it seems that all the tensors are created with the correct types.
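`tf.decode_csv` takes the type of each output column from the matching entry in `record_defaults`, so `['M']` gives a string column, `[0.0]` a float32 column and `[0]` an int32 column. The pure-Python sketch below mimics only that typing rule (`decode_csv_line` is a hypothetical helper, not a TensorFlow API). It shows that the `rings` label comes out as an integer, consistent with the `int32` in the error message.

```python
import csv

# Same column names and defaults as the question: one single-element list per
# column; the element's Python type decides the column's type.
CSV_COLUMNS = ['sex', 'length', 'diameter', 'height', 'whole_weight',
               'shucked_weight', 'viscera_weight', 'shell_weight', 'rings']
CSV_COLUMN_DEFAULTS = [['M'], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0]]


def decode_csv_line(line, record_defaults):
    """Hypothetical helper: convert each field to the type of its default.

    Empty fields fall back to the default value.
    """
    fields = next(csv.reader([line]))
    if len(fields) != len(record_defaults):
        raise ValueError('expected %d fields, got %d'
                         % (len(record_defaults), len(fields)))
    values = []
    for field, (default,) in zip(fields, record_defaults):
        values.append(type(default)(field) if field else default)
    return values


def parse_csv(line):
    """Mirror of the question's parse_csv(): split off 'rings' as the label."""
    columns = decode_csv_line(line, CSV_COLUMN_DEFAULTS)
    features = dict(zip(CSV_COLUMNS, columns))
    labels = features.pop('rings')
    return features, labels


# An illustrative row in the documented format.
features, label = parse_csv('M,0.455,0.365,0.095,0.514,0.2245,0.101,0.15,15')
print(type(features['sex']).__name__, type(features['length']).__name__,
      type(label).__name__)  # → str float int
```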
Can you explain what is going wrong? I believe I followed the tutorial's code logic, but I can't work out how to fix this.
Source code:
```python
import tensorflow as tf
import shutil

_CSV_COLUMNS = [
    'sex', 'length', 'diameter', 'height', 'whole_weight',
    'shucked_weight', 'viscera_weight', 'shell_weight', 'rings'
]

_CSV_COLUMN_DEFAULTS = [['M'], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0]]

_NUM_EXAMPLES = {
    'train': 3000,
    'validation': 1177,
}


def build_model_columns():
    """Builds a set of wide feature columns."""
    # Continuous columns
    sex = tf.feature_column.categorical_column_with_hash_bucket('sex', hash_bucket_size=1000)
    length = tf.feature_column.numeric_column('length', dtype=tf.float32)
    diameter = tf.feature_column.numeric_column('diameter', dtype=tf.float32)
    height = tf.feature_column.numeric_column('height', dtype=tf.float32)
    whole_weight = tf.feature_column.numeric_column('whole_weight', dtype=tf.float32)
    shucked_weight = tf.feature_column.numeric_column('shucked_weight', dtype=tf.float32)
    viscera_weight = tf.feature_column.numeric_column('viscera_weight', dtype=tf.float32)
    shell_weight = tf.feature_column.numeric_column('shell_weight', dtype=tf.float32)

    base_columns = [sex, length, diameter, height, whole_weight,
                    shucked_weight, viscera_weight, shell_weight]

    return base_columns


def build_estimator():
    """Build an estimator appropriate for the given model type."""
    base_columns = build_model_columns()

    return tf.estimator.LinearClassifier(
        model_dir="~/models/albones/",
        feature_columns=base_columns,
        label_vocabulary=_CSV_COLUMNS)


def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Generate an input function for the Estimator."""
    assert tf.gfile.Exists(data_file), (
        '%s not found. Please make sure you have either run data_download.py or '
        'set both arguments --train_data and --test_data.' % data_file)

    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('rings')
        return features, labels

    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

    dataset = dataset.map(parse_csv)

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels


def main(unused_argv):
    # Clean up the model directory if present
    shutil.rmtree("/home/lawrence/models/albones/", ignore_errors=True)
    model = build_estimator()

    # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
    for n in range(40 // 2):
        model.train(input_fn=lambda: input_fn(
            "/home/lawrence/abalone.data", 2, True, 40))

        results = model.evaluate(input_fn=lambda: input_fn(
            "/home/lawrence/abalone.data", 1, False, 40))

        # Display evaluation metrics
        print('Results at epoch', (n + 1) * 2)
        print('-' * 60)

        for key in sorted(results):
            print('%s: %s' % (key, results[key]))


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)
```
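For reference, the `tf.estimator.LinearClassifier` documentation states that when `label_vocabulary` is given, labels must be strings drawn from that vocabulary; when it is omitted, labels must already be integer class ids in `{0, ..., n_classes-1}`. In the code above the vocabulary is `_CSV_COLUMNS` (the column names), while the labels come from the integer `rings` column. The pure-Python sketch below checks that contract. `check_labels` is a hypothetical helper, not part of TensorFlow, and it ignores the estimator's binary/float label case. `n_classes=30` is an assumed upper bound on ring counts; check the actual maximum in your data.

```python
CSV_COLUMNS = ['sex', 'length', 'diameter', 'height', 'whole_weight',
               'shucked_weight', 'viscera_weight', 'shell_weight', 'rings']


def check_labels(labels, label_vocabulary=None, n_classes=2):
    """Hypothetical check of the documented label contract.

    With a vocabulary: labels must be strings found in it; each is mapped to
    its index. Without one: labels must be ints in range(n_classes).
    """
    if label_vocabulary is not None:
        if not all(isinstance(label, str) for label in labels):
            raise TypeError('labels must be strings when label_vocabulary is set')
        missing = [label for label in labels if label not in label_vocabulary]
        if missing:
            raise ValueError('labels not in vocabulary: %r' % missing)
        return [label_vocabulary.index(label) for label in labels]
    if not all(isinstance(label, int) and 0 <= label < n_classes
               for label in labels):
        raise ValueError('labels must be ints in [0, %d)' % n_classes)
    return list(labels)


rings = [15, 7, 9]  # integer labels, like the 'rings' column

try:
    # The combination used in the question: int labels + string vocabulary.
    check_labels(rings, label_vocabulary=CSV_COLUMNS)
except TypeError as err:
    print('TypeError:', err)

# Integer labels with no vocabulary; n_classes=30 is an assumption.
print(check_labels(rings, n_classes=30))  # → [15, 7, 9]
```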
Here is the description of the dataset's columns from abalone.names:
```
Name            Data Type   Meas.  Description
----            ---------   -----  -----------
Sex             nominal            M, F, [or] I (infant)
Length          continuous  mm     Longest shell measurement
Diameter        continuous  mm     perpendicular to length
Height          continuous  mm     with meat in shell
Whole weight    continuous  grams  whole abalone
Shucked weight  continuous  grams  weight of meat
Viscera weight  continuous  grams  gut weight (after bleeding)
Shell weight    continuous  grams  after being dried
Rings           integer            +1.5 gives the age in years
```
Each entry in the dataset is a line of comma-separated values in this column order, one entry per line.
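As a sanity check on that file format, here is a minimal pure-Python sketch that reads comma-separated lines, validates them against the schema above (9 fields, sex in M/F/I, integer rings) and computes the age as rings + 1.5, as abalone.names states. `read_abalone` is a hypothetical helper, and the two sample rows are illustrative values in the documented format.

```python
import csv
import io

# Schema from abalone.names: sex is nominal (M, F or I), the next seven
# columns are continuous, and rings is an integer.
SEX_VALUES = {'M', 'F', 'I'}
NUM_FIELDS = 9


def read_abalone(text):
    """Hypothetical reader: yield (features, rings, age) per line of `text`.

    Age in years is rings + 1.5, as stated in abalone.names.
    """
    for line_no, row in enumerate(csv.reader(io.StringIO(text)), start=1):
        if not row:
            continue  # skip blank lines
        if len(row) != NUM_FIELDS:
            raise ValueError('line %d: expected %d fields, got %d'
                             % (line_no, NUM_FIELDS, len(row)))
        sex, measurements, rings = row[0], row[1:8], int(row[8])
        if sex not in SEX_VALUES:
            raise ValueError('line %d: unknown sex %r' % (line_no, sex))
        features = {'sex': sex, 'measurements': [float(x) for x in measurements]}
        yield features, rings, rings + 1.5


sample = ('M,0.455,0.365,0.095,0.514,0.2245,0.101,0.15,15\n'
          'I,0.33,0.255,0.08,0.205,0.0895,0.0395,0.055,7\n')
for features, rings, age in read_abalone(sample):
    print(features['sex'], rings, age)
```

For the sample above this prints `M 15 16.5` and `I 7 8.5`.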
Answer: