Loading a custom dataset (similar to the 20 newsgroups dataset) in scikit-learn for text document classification

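I am trying to run this scikit-learn example code on my custom dataset of TED Talks. Each directory is a topic, and under it are text files containing the description of each TED Talk.

This is the tree structure of my dataset: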

Topics/
|-- Activism
|   |-- 1149.txt
|   |-- 1444.txt
|   |-- 157.txt
|   |-- 1616.txt
|   |-- 1706.txt
|   |-- 1718.txt
|-- Adventure
|   |-- 1036.txt
|   |-- 1777.txt
|   |-- 2930.txt
|   |-- 2968.txt
|   |-- 3027.txt
|   |-- 3290.txt
|-- Advertising
|   |-- 3673.txt
|   |-- 3685.txt
|   |-- 6567.txt
|   `-- 6925.txt
|-- Africa
|   |-- 1045.txt
|   |-- 1072.txt
|   |-- 1103.txt
|   |-- 1112.txt
|-- Aging
|   |-- 1848.txt
|   |-- 2495.txt
|   |-- 2782.txt
|-- Agriculture
|   |-- 3469.txt
|   |-- 4140.txt
|   |-- 4733.txt
|   |-- 4939.txt

I arranged the dataset this way to mimic the 20 newsgroups dataset, whose tree structure looks like this:

20news-18828/
|-- alt.atheism
|   |-- 49960
|   |-- 51060
|   |-- 51119
|-- comp.graphics
|   |-- 37261
|   |-- 37913
|   |-- 37914
|   |-- 37915
|   |-- 37916
|   |-- 37917
|   |-- 37918
|-- comp.os.ms-windows.misc
|   |-- 10000
|   |-- 10001
|   |-- 10002
|   |-- 10003
|   |-- 10004
|   |-- 10005

In the original code (lines 98-124), this is how the training and test data are loaded directly from scikit-learn:

print("Loading 20 newsgroups dataset for categories:")
print(categories if categories else "all")

data_train = fetch_20newsgroups(subset='train', categories=categories,
                                shuffle=True, random_state=42,
                                remove=remove)

data_test = fetch_20newsgroups(subset='test', categories=categories,
                               shuffle=True, random_state=42,
                               remove=remove)
print('data loaded')

categories = data_train.target_names    # for case categories == None


def size_mb(docs):
    return sum(len(s.encode('utf-8')) for s in docs) / 1e6

data_train_size_mb = size_mb(data_train.data)
data_test_size_mb = size_mb(data_test.data)

print("%d documents - %0.3fMB (training set)" % (
    len(data_train.data), data_train_size_mb))
print("%d documents - %0.3fMB (test set)" % (
    len(data_test.data), data_test_size_mb))
print("%d categories" % len(categories))
print()

# split a training set and a test set
y_train, y_test = data_train.target, data_test.target

Because this dataset ships with scikit-learn, its labels and other metadata are built in. In my case, I know how to load the dataset (line 84):

dataset = load_files('./TED_dataset/Topics/')

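After that I don't know what to do. I would like to know how to split this data into training and test sets and how to produce these from my dataset, along with their labels: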

data_train.data,  data_test.data 

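All in all, I just want to load my dataset and run this code without errors. I have uploaded the dataset for anyone who wants to take a look.

I referred to this question, which briefly discusses loading train and test data. I would also like to know how to obtain data_train.target_names from my dataset.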

Edit:

I tried to get the training and test sets like this, but it raised an error:

dataset = load_files('./TED_dataset/Topics/')
train, test = train_test_split(dataset, train_size = 0.8)

The updated code is here.


Answer:

I think you are looking for something like this:

In [1]: from sklearn.datasets import load_files
In [2]: from sklearn.model_selection import train_test_split  # sklearn.cross_validation before scikit-learn 0.18
In [3]: bunch = load_files('./Topics')
In [4]: X_train, X_test, y_train, y_test = train_test_split(bunch.data, bunch.target, test_size=.4)
# Then go on to train your model and validate it.
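As a rough illustration of the "train and validate" step mentioned in the comment above, here is a minimal sketch. The toy documents, the label values, and the choice of TfidfVectorizer with LogisticRegression are all assumptions made for the example; the original example script compares several other classifiers, and any of them could be swapped in.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline

# Hypothetical stand-ins for bunch.data and bunch.target.
docs = [
    "activists march for voting rights",
    "a protest movement changes the law",
    "citizens organize a peaceful campaign",
    "volunteers rally for social justice",
    "a petition drives political change",
    "climbers reach a remote mountain summit",
    "a sailor crosses the ocean alone",
    "explorers trek across the polar ice",
    "divers map an underwater cave",
    "a pilot flies around the world",
]
labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]  # 0 = Activism, 1 = Adventure

# stratify keeps both classes present in the training and the test split.
X_train, X_test, y_train, y_test = train_test_split(
    docs, labels, test_size=0.4, random_state=42, stratify=labels
)

# The pipeline fits the vectorizer on the training texts only,
# then reuses the same vocabulary for the test texts.
model = make_pipeline(TfidfVectorizer(), LogisticRegression())
model.fit(X_train, y_train)

predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print("accuracy on the held-out documents: %.2f" % accuracy)
```

With a dataset this small the accuracy figure is meaningless; the point is only the order of the steps: split, fit on the training part, score on the held-out part.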

Note that bunch.target is an array of integers, and each integer is an index into the category names stored in bunch.target_names.

In [14]: X_test[:2]
Out[14]:
['Psychologist Philip Zimbardo asks, "Why are boys struggling?" He shares some stats (lower graduation rates, greater worries about intimacy and relationships) and suggests a few reasons -- and challenges the TED community to think about solutions.Philip Zimbardo was the leader of the notorious 1971 Stanford Prison Experiment -- and an expert witness at Abu Ghraib. His book The Lucifer Effect explores the nature of evil; now, in his new work, he studies the nature of heroism.',
 'Human growth has strained the Earth\'s resources, but as Johan Rockstrom reminds us, our advances also give us the science to recognize this and change behavior. His research has found nine "planetary boundaries" that can guide us in protecting our planet\'s many overlapping ecosystems.If Earth is a self-regulating system, it\'s clear that human activity is capable of disrupting it. Johan Rockstrom has led a team of scientists to define the nine Earth systems that need to be kept within bounds for Earth to keep itself in balance.']

In [15]: y_test[:2]
Out[15]: array([ 84, 113])

In [16]: [bunch.target_names[idx] for idx in y_test[:2]]
Out[16]: ['Education', 'Global issues']
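To get objects that behave like the data_train and data_test used in the example script (with .data, .target and .target_names), one option is to wrap each split in a sklearn.utils.Bunch. The sketch below is only an illustration: it builds a tiny throwaway directory in the same layout as Topics/ with made-up file contents, so the folder names, texts, and split parameters are assumptions. On Python 3, load_files returns bytes unless an encoding is passed, so encoding="utf-8" is given here.

```python
import tempfile
from pathlib import Path

from sklearn.datasets import load_files
from sklearn.model_selection import train_test_split
from sklearn.utils import Bunch

# Hypothetical miniature dataset: one folder per topic, one file per talk.
toy_topics = {
    "Activism": ["march for rights", "a protest campaign",
                 "citizens demand change", "a petition for justice"],
    "Africa": ["innovation in Nairobi", "farming in the Sahel",
               "music from Lagos", "wildlife in the Serengeti"],
}

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "Topics"
    for topic, texts in toy_topics.items():
        (root / topic).mkdir(parents=True)
        for i, text in enumerate(texts):
            (root / topic / f"{i}.txt").write_text(text, encoding="utf-8")

    # Folder names become target_names; file contents are read into memory.
    bunch = load_files(str(root), encoding="utf-8")

# Split the documents and their integer labels together.
X_train, X_test, y_train, y_test = train_test_split(
    bunch.data, bunch.target, test_size=0.4, random_state=42,
    stratify=bunch.target,
)

# Wrap each split so it exposes the same attributes as fetch_20newsgroups.
data_train = Bunch(data=X_train, target=y_train, target_names=bunch.target_names)
data_test = Bunch(data=X_test, target=y_test, target_names=bunch.target_names)

categories = data_train.target_names
print(categories)
print([categories[idx] for idx in data_test.target])
```

With data_train and data_test defined this way, the later lines of the example that read data_train.data, data_train.target and data_train.target_names should not need to change, although that has only been checked against the attributes shown in the question.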
