根据TensorFlow数据集按键分组和按键批处理

我目前正在处理一个TensorFlow的问题,需要生成批次,其中批次中的所有张量都具有特定的键值。如果可能的话,我希望使用数据集API。这是可行的吗?

筛选、映射、应用等操作都是针对单个元素进行的,而我需要一种按键分组的方法。我发现了tf.data.experimental.group_by_window和tf.data.experimental.group_by_reducer,看起来很有前景,但还没有找到解决方案。

最好给出一个例子来说明:

dataset:feature,label1,word12,word23,word31,word13,word31,word11,word12,word23,word31,word13,word31,word11,word1

按“key”特征分组,最大批次大小为3,生成的批次如下:

batch1[[1,word1], [1,word1], [1,word1]]batch2[[1,word1], [1,word1], [1,word1]]batch3[[1,word1]]batch4[[2,word2] [2,word2]]batch5[[3,word3], [3,word3], [3,word3]]batch6[[3,word3]]

编辑:尽管示例中显示的顺序,但每个批次的顺序并不重要


回答:

我认为这可以实现你想要的转换:

import tensorflow as tfimport randomrandom.seed(100)# 输入数据label = list(range(15))# 打乱数据random.shuffle(label)# 根据标签数据生成特征feature = [lbl // 5 for lbl in label]batch_size = 3print('数据:')print(*zip(feature, label), sep='\n')with tf.Graph().as_default(), tf.Session() as sess:    # 从数据数组创建数据集    ds = tf.data.Dataset.from_tensor_slices({'feature': feature, 'label': label})    # 按窗口分组    ds = ds.apply(tf.data.experimental.group_by_window(        # 使用特征作为键        key_func=lambda elem: tf.to_int64(elem['feature']),        # 将每个窗口转换为批次        reduce_func=lambda _, window: window.batch(batch_size),        # 使用批次大小作为窗口大小        window_size=batch_size))    # 迭代器    iter = ds.make_one_shot_iterator().get_next()    # 显示数据集内容    print('结果:')    while True:        try:            print(sess.run(iter))        except tf.errors.OutOfRangeError: break

输出:

数据:(2, 11)(1, 8)(2, 12)(0, 3)(1, 9)(0, 0)(0, 4)(0, 1)(2, 10)(1, 5)(1, 6)(2, 14)(2, 13)(1, 7)(0, 2)结果:{'feature': array([0, 0, 0]), 'label': array([3, 0, 4])}{'feature': array([2, 2, 2]), 'label': array([11, 12, 10])}{'feature': array([1, 1, 1]), 'label': array([8, 9, 5])}{'feature': array([0, 0]), 'label': array([1, 2])}{'feature': array([1, 1]), 'label': array([6, 7])}{'feature': array([2, 2]), 'label': array([14, 13])}

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注