根据类别返回numpy数组的分布样本

注意:我对这个任务依赖于numpy。

我正在尝试编写一个能够完成以下目标的单一函数:

  1. 将数据集加载到numpy数组中
  2. 将数据集分成5个“相等”(或尽可能相等)的折叠
  3. 对于每个折叠,确保数据按80/20的比例分配给训练和测试
  4. 这里有一个难点。原始输入数据集是“标记”的,最后一列包含分类。折叠需要保持与输入集相同的类别大小分布。

例如,如果我有input=100 samples(rows),并且有两个类别(由最后一列的值表示),A和B,分别占33%和67%,那么我应该创建5个折叠,每个折叠包含20个样本,其中6或7个样本是A,13或14个样本是B。

这就是我正在努力实现的。我不知道如何确保折叠本身包含适当的类别抽样分布。

以下是迄今为止我的尝试代码。我已经编写了两个函数,一个可以告诉我输入类别的分布,另一个可以创建5个折叠。然而,我需要找到一种方法将它们结合起来,创建5个保持各自分布的折叠。

import numpydef csv_to_array(file):    # 打开文件,并以逗号分隔加载逗号分隔值文件    data = open(file, 'r')    data = numpy.loadtxt(data, delimiter=',')    # 循环处理数组中的数据    for index in range(len(data)):        # 使用try catch尝试转换为float,如果无法转换为float,则转换为0        try:            data[index] = [float(x) for x in data[index]]        except Exception:            data[index] = 0        except ValueError:            data[index] = 0    # 返回现在类型格式化的数据    return datadef class_distribution(dataset):    dataset = numpy.asarray(dataset)    num_total_rows = dataset.shape[0]    num_columns = dataset.shape[1]    classes = dataset[:,num_columns-1]    classes = numpy.unique(classes)    for aclass in classes:        total = 0        for row in dataset:            if numpy.array_equal(aclass, row[-1]):                total = total + 1            else:                continue        print(aclass, " Has: ", ((total/num_total_rows) * 100))        print(aclass, " : ", total)def create_folds(dataset):    # print("DATASET", dataset)    numpy.random.shuffle(dataset)    num_rows = dataset.shape[0]    split_mark = int(num_rows / 5)    folds = []    fold_sets = []    temp1 = dataset[:split_mark]    # print("TEMP1", temp1)    temp2 = dataset[split_mark:split_mark*2]    # print("TEMP2", temp2)    temp3 = dataset[split_mark*2:split_mark*3]    # print("TEMP3", temp3)    temp4 = dataset[split_mark*3:split_mark*4]    # print("TEMP4", temp4)    temp5 = dataset[split_mark*4:]    # print("TEMP5", temp5)    folds.append(temp1)    folds.append(temp2)    folds.append(temp3)    folds.append(temp4)    folds.append(temp5)    folds = numpy.asarray(folds)    # print(folds)    return foldsdef main():    print("BEGINNING CFV")    ecoli = csv_to_array('Classification/ecoli.csv')    # print(len(ecoli))    class_distribution(ecoli)    create_folds(ecoli)main()

这是我正在处理的csv文件示例,最后一列表示类别。这是对来自UCI机器学习库的ecoli数据集的修改:

0.61,0.45,0.48,0.5,0.48,0.35,0.41,00.17,0.38,0.48,0.5,0.45,0.42,0.5,00.44,0.35,0.48,0.5,0.55,0.55,0.61,00.43,0.4,0.48,0.5,0.39,0.28,0.39,00.42,0.35,0.48,0.5,0.58,0.15,0.27,00.23,0.33,0.48,0.5,0.43,0.33,0.43,00.37,0.52,0.48,0.5,0.42,0.42,0.36,00.29,0.3,0.48,0.5,0.45,0.03,0.17,00.22,0.36,0.48,0.5,0.35,0.39,0.47,00.23,0.58,0.48,0.5,0.37,0.53,0.59,00.47,0.47,0.48,0.5,0.22,0.16,0.26,00.54,0.47,0.48,0.5,0.28,0.33,0.42,00.51,0.37,0.48,0.5,0.35,0.36,0.45,00.4,0.35,0.48,0.5,0.45,0.33,0.42,00.44,0.34,0.48,0.5,0.3,0.33,0.43,00.44,0.49,0.48,0.5,0.39,0.38,0.4,00.43,0.32,0.48,0.5,0.33,0.45,0.52,00.49,0.43,0.48,0.5,0.49,0.3,0.4,00.47,0.28,0.48,0.5,0.56,0.2,0.25,00.32,0.33,0.48,0.5,0.6,0.06,0.2,00.34,0.35,0.48,0.5,0.51,0.49,0.56,00.35,0.34,0.48,0.5,0.46,0.3,0.27,00.38,0.3,0.48,0.5,0.43,0.29,0.39,00.38,0.44,0.48,0.5,0.43,0.2,0.31,00.41,0.51,0.48,0.5,0.58,0.2,0.31,00.34,0.42,0.48,0.5,0.41,0.34,0.43,00.51,0.49,0.48,0.5,0.53,0.14,0.26,00.25,0.51,0.48,0.5,0.37,0.42,0.5,00.29,0.28,0.48,0.5,0.5,0.42,0.5,00.25,0.26,0.48,0.5,0.39,0.32,0.42,00.24,0.41,0.48,0.5,0.49,0.23,0.34,00.17,0.39,0.48,0.5,0.53,0.3,0.39,00.04,0.31,0.48,0.5,0.41,0.29,0.39,00.61,0.36,0.48,0.5,0.49,0.35,0.44,00.34,0.51,0.48,0.5,0.44,0.37,0.46,00.28,0.33,0.48,0.5,0.45,0.22,0.33,00.4,0.46,0.48,0.5,0.42,0.35,0.44,00.23,0.34,0.48,0.5,0.43,0.26,0.37,00.37,0.44,0.48,0.5,0.42,0.39,0.47,00,0.38,0.48,0.5,0.42,0.48,0.55,00.39,0.31,0.48,0.5,0.38,0.34,0.43,00.3,0.44,0.48,0.5,0.49,0.22,0.33,00.27,0.3,0.48,0.5,0.71,0.28,0.39,00.17,0.52,0.48,0.5,0.49,0.37,0.46,00.36,0.42,0.48,0.5,0.53,0.32,0.41,00.3,0.37,0.48,0.5,0.43,0.18,0.3,00.26,0.4,0.48,0.5,0.36,0.26,0.37,00.4,0.41,0.48,0.5,0.55,0.22,0.33,00.22,0.34,0.48,0.5,0.42,0.29,0.39,00.44,0.35,0.48,0.5,0.44,0.52,0.59,00.27,0.42,0.48,0.5,0.37,0.38,0.43,00.16,0.43,0.48,0.5,0.54,0.27,0.37,00.06,0.61,0.48,0.5,0.49,0.92,0.37,10.44,0.52,0.48,0.5,0.43,0.47,0.54,10.63,0.47,0.48,0.5,0.51,0.82,0.84,10.23,0.48,0.48,0.5,0.59,0.88,0.89,10.34,0.49,0.48,0.5,0.58,0.85,0.8,10.43,0.4,0.48,0.5,0.58,0.75,0.78,10.46,0.61,0.48,0.5,0.48,0.86,0.87,10.27,0.35,0.48,0.5,0.51,0.77,0.79,10.52,0.39,0.48,0.5,0.65,0.71,0.73,10.29,0.47,0.48,0.5,0.71,0.65,0.69,10.55,0.47,0.48,0.5,0.57,0.78,0.8,10.12,0.67,0.48,0.5,0.74,0.58,0.63,10.4,0.5,0.48,0.5,0.65,0.82,0.84,10.73,0.36,0.48,0.5,0.53,0.91,0.92,10.84,0.44,0.48,0.5,0.48,0.71,0.74,10.48,0.45,0.48,0.5,0.6,0.78,0.8,10.54,0.49,0.48,0.5,0.4,0.87,0.88,10.48,0.41,0.48,0.5,0.51,0.9,0.88,10.5,0.66,0.48,0.5,0.31,0.92,0.92,10.72,0.46,0.48,0.5,0.51,0.66,0.7,10.47,0.55,0.48,0.5,0.58,0.71,0.75,10.33,0.56,0.48,0.5,0.33,0.78,0.8,10.64,0.58,0.48,0.5,0.48,0.78,0.73,10.11,0.5,0.48,0.5,0.58,0.72,0.68,10.31,0.36,0.48,0.5,0.58,0.94,0.94,10.68,0.51,0.48,0.5,0.71,0.75,0.78,10.69,0.39,0.48,0.5,0.57,0.76,0.79,10.52,0.54,0.48,0.5,0.62,0.76,0.79,10.46,0.59,0.48,0.5,0.36,0.76,0.23,10.36,0.45,0.48,0.5,0.38,0.79,0.17,10,0.51,0.48,0.5,0.35,0.67,0.44,10.1,0.49,0.48,0.5,0.41,0.67,0.21,10.3,0.51,0.48,0.5,0.42,0.61,0.34,10.61,0.47,0.48,0.5,0,0.8,0.32,10.63,0.75,0.48,0.5,0.64,0.73,0.66,10.71,0.52,0.48,0.5,0.64,1,0.99,10.72,0.42,0.48,0.5,0.65,0.77,0.79,20.79,0.41,0.48,0.5,0.66,0.81,0.83,20.83,0.48,0.48,0.5,0.65,0.76,0.79,20.69,0.43,0.48,0.5,0.59,0.74,0.77,20.79,0.36,0.48,0.5,0.46,0.82,0.7,20.78,0.33,0.48,0.5,0.57,0.77,0.79,20.75,0.37,0.48,0.5,0.64,0.7,0.74,20.59,0.29,0.48,0.5,0.64,0.75,0.77,20.67,0.37,0.48,0.5,0.54,0.64,0.68,20.66,0.48,0.48,0.5,0.54,0.7,0.74,20.64,0.46,0.48,0.5,0.48,0.73,0.76,20.76,0.71,0.48,0.5,0.5,0.71,0.75,20.84,0.49,0.48,0.5,0.55,0.78,0.74,20.77,0.55,0.48,0.5,0.51,0.78,0.74,20.81,0.44,0.48,0.5,0.42,0.67,0.68,20.58,0.6,0.48,0.5,0.59,0.73,0.76,20.63,0.42,0.48,0.5,0.48,0.77,0.8,20.62,0.42,0.48,0.5,0.58,0.79,0.81,20.86,0.39,0.48,0.5,0.59,0.89,0.9,20.81,0.53,0.48,0.5,0.57,0.87,0.88,20.87,0.49,0.48,0.5,0.61,0.76,0.79,20.47,0.46,0.48,0.5,0.62,0.74,0.77,20.76,0.41,0.48,0.5,0.5,0.59,0.62,20.7,0.53,0.48,0.5,0.7,0.86,0.87,20.64,0.45,0.48,0.5,0.67,0.61,0.66,20.81,0.52,0.48,0.5,0.57,0.78,0.8,20.73,0.26,0.48,0.5,0.57,0.75,0.78,20.49,0.61,1,0.5,0.56,0.71,0.74,20.88,0.42,0.48,0.5,0.52,0.73,0.75,20.84,0.54,0.48,0.5,0.75,0.92,0.7,20.63,0.51,0.48,0.5,0.64,0.72,0.76,20.86,0.55,0.48,0.5,0.63,0.81,0.83,20.79,0.54,0.48,0.5,0.5,0.66,0.68,20.57,0.38,0.48,0.5,0.06,0.49,0.33,20.78,0.44,0.48,0.5,0.45,0.73,0.68,20.78,0.68,0.48,0.5,0.83,0.4,0.29,30.63,0.69,0.48,0.5,0.65,0.41,0.28,30.67,0.88,0.48,0.5,0.73,0.5,0.25,30.61,0.75,0.48,0.5,0.51,0.33,0.33,30.67,0.84,0.48,0.5,0.74,0.54,0.37,30.74,0.9,0.48,0.5,0.57,0.53,0.29,30.73,0.84,0.48,0.5,0.86,0.58,0.29,30.75,0.76,0.48,0.5,0.83,0.57,0.3,30.77,0.57,0.48,0.5,0.88,0.53,0.2,30.74,0.78,0.48,0.5,0.75,0.54,0.15,30.68,0.76,0.48,0.5,0.84,0.45,0.27,30.56,0.68,0.48,0.5,0.77,0.36,0.45,30.65,0.51,0.48,0.5,0.66,0.54,0.33,30.52,0.81,0.48,0.5,0.72,0.38,0.38,30.64,0.57,0.48,0.5,0.7,0.33,0.26,30.6,0.76,1,0.5,0.77,0.59,0.52,30.69,0.59,0.48,0.5,0.77,0.39,0.21,30.63,0.49,0.48,0.5,0.79,0.45,0.28,30.71,0.71,0.48,0.5,0.68,0.43,0.36,30.68,0.63,0.48,0.5,0.73,0.4,0.3,30.74,0.49,0.48,0.5,0.42,0.54,0.36,40.7,0.61,0.48,0.5,0.56,0.52,0.43,40.66,0.86,0.48,0.5,0.34,0.41,0.36,40.73,0.78,0.48,0.5,0.58,0.51,0.31,40.65,0.57,0.48,0.5,0.47,0.47,0.51,40.72,0.86,0.48,0.5,0.17,0.55,0.21,40.67,0.7,0.48,0.5,0.46,0.45,0.33,40.67,0.81,0.48,0.5,0.54,0.49,0.23,40.67,0.61,0.48,0.5,0.51,0.37,0.38,40.63,1,0.48,0.5,0.35,0.51,0.49,40.57,0.59,0.48,0.5,0.39,0.47,0.33,40.71,0.71,0.48,0.5,0.4,0.54,0.39,40.66,0.74,0.48,0.5,0.31,0.38,0.43,40.67,0.81,0.48,0.5,0.25,0.42,0.25,40.64,0.72,0.48,0.5,0.49,0.42,0.19,40.68,0.82,0.48,0.5,0.38,0.65,0.56,40.32,0.39,0.48,0.5,0.53,0.28,0.38,40.7,0.64,0.48,0.5,0.47,0.51,0.47,40.63,0.57,0.48,0.5,0.49,0.7,0.2,40.69,0.65,0.48,0.5,0.63,0.48,0.41,40.43,0.59,0.48,0.5,0.52,0.49,0.56,40.74,0.56,0.48,0.5,0.47,0.68,0.3,40.71,0.57,0.48,0.5,0.48,0.35,0.32,40.61,0.6,0.48,0.5,0.44,0.39,0.38,40.59,0.61,0.48,0.5,0.42,0.42,0.37,40.74,0.74,0.48,0.5,0.31,0.53,0.52,4

回答:

在采纳了@AlexL的建议后,我查看了StratifiedKFold代码,并开发了以下两个修改后的函数:

# 此函数返回类别列表及其相关权重(即分布)def class_distribution(dataset):    dataset = numpy.asarray(dataset)    num_total_rows = dataset.shape[0]    num_columns = dataset.shape[1]    classes = dataset[:, num_columns - 1]    classes = numpy.unique(classes)    class_weights = []    # 逐个循环处理类别    for aclass in classes:        total = 0        weight = 0        for row in dataset:            if numpy.array_equal(aclass, row[-1]):                total = total + 1            else:                continue        weight = float((total / num_total_rows))        class_weights.append(weight)    class_weights = numpy.asarray(class_weights)    return classes, class_weights# 此函数对分类执行k交叉折叠验证def cross_fold_validation_classification(dataset, k):    temp_dataset = numpy.asarray(dataset)    classes, class_weights = class_distribution(temp_dataset)    total_num_rows = temp_dataset.shape[0]    data = numpy.copy(temp_dataset)    total_fold_array = []    for _ in range(k):        curr_fold_array = []        # 循环处理每个类别及其相关权重        for a_class, a_class_weight in zip(classes, class_weights):            numpy.random.shuffle(data)            num_added = 0            num_to_add = float((((a_class_weight * total_num_rows)) / k))            tot = 0            for row in data:                curr = row[-1]                if num_added >= num_to_add:                    break                else:                    if (a_class == curr):                        curr_fold_array.append(row)                        num_added = num_added + 1                        numpy.delete(data, tot)                tot = tot + 1        total_fold_array.append(curr_fold_array)return total_fold_array

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注