How to feed nested arrays into an SVM model

My problem is as follows:

I have an array that contains the feature vectors for several audio files. For example, if there are 10 audio files, this array has length 10.

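One of my features is itself a list (the list holds information about one particular feature of the audio file), so for a given audio file the feature vector looks like this: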

```
array([0.03861840871664194, 187.72393405210002, 62.59881268743305,
       0.2911392405063291,
       array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
              4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
              5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
              4984.93652344, 2756.25      , 1991.82128906, 2551.68457031,
              2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
              3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
              2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
              4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
              5932.39746094, 6244.62890625, 6072.36328125, 6201.5625    ,
              6158.49609375, 6201.5625    , 6233.86230469, 6061.59667969])],
      dtype=object)
```

Now, when I try to feed this data into an SVM model:

```python
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3)
model = svm.SVC()
model.fit(X_train, y_train)
yt_p = model.predict(X_train)
yv_p = model.predict(X_val)
```

I get this error: `ValueError: setting an array element with a sequence.`

How should I structure my feature vectors so that I can feed them into the SVM model?

Edit:

Here is an example of X.

If we have 5 audio files, X would be:

```
array([[0.017455393927437918, 227.66237105624407, 32.42076654734572,
        0.3867924528301887,
        array([1851.85546875, 2433.25195312, 3057.71484375, 3079.24804688,
               3079.24804688, 3068.48144531, 3046.94824219, 3359.1796875 ,
               3908.27636719, 4618.87207031, 4618.87207031, 4521.97265625,
               4091.30859375, 3111.54785156, 3100.78125   , 2863.91601562,
               1561.15722656, 1119.7265625 , 1065.89355469,  947.4609375 ,
                979.76074219,  990.52734375,  990.52734375, 1356.59179688,
               2077.95410156, 2993.11523438, 3025.41503906, 3068.48144531,
               3079.24804688, 3090.01464844, 3100.78125   , 3111.54785156,
               2993.11523438, 3100.78125   , 3079.24804688, 2853.14941406,
               1205.859375  , 1281.22558594, 1614.99023438, 2131.78710938,
               2325.5859375 , 2034.88769531, 1916.45507812, 1744.18945312,
               1851.85546875, 2357.88574219, 2368.65234375, 1916.45507812,
               1959.52148438, 1959.52148438, 1754.95605469, 1787.25585938,
               2207.15332031])],
       [0.03861840871664194, 187.72393405210002, 62.59881268743305,
        0.2911392405063291,
        array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
               4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
               5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
               4984.93652344, 2756.25      , 1991.82128906, 2551.68457031,
               2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
               3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
               2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
               4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
               5932.39746094, 6244.62890625, 6072.36328125, 6201.5625    ,
               6158.49609375, 6201.5625    , 6233.86230469, 6061.59667969])],
       [0.042435441297643324, 128.81225073038124, 20.912528554426807,
        0.313953488372093,
        array([4349.70703125, 4242.04101562, 4274.34082031, 4123.60839844,
               4457.37304688, 4834.20410156, 4661.93847656, 4306.640625  ,
               4231.27441406, 4543.50585938, 4435.83984375, 6201.5625    ,
               8817.84667969, 8817.84667969,  742.89550781,  721.36230469,
                732.12890625,  732.12890625,  710.59570312,  721.36230469,
                925.92773438, 1119.7265625 , 1141.25976562, 1431.95800781,
               7762.71972656, 7934.98535156, 7891.91894531, 7332.05566406,
               3789.84375   , 2799.31640625, 2831.61621094, 2217.91992188,
                581.39648438,  602.9296875 , 2217.91992188, 2228.68652344,
               2368.65234375, 2519.38476562, 2863.91601562, 3682.17773438,
               3649.87792969, 4188.20800781, 4112.84179688])],
       [0.006295381642571726, 130.28309914454434, 5.193614287487564,
        0.2411764705882353,
        array([7978.05175781, 8010.3515625 , 8118.01757812, 8430.24902344,
               8257.98339844, 8451.78222656, 8591.74804688, 8677.88085938,
               8796.31347656, 8850.14648438, 8796.31347656, 8925.51269531,
               6244.62890625,  344.53125   ,  344.53125   , 1614.99023438,
               2325.5859375 , 2971.58203125, 3316.11328125, 3617.578125  ,
               3294.58007812, 2788.54980469, 2637.81738281, 2702.41699219,
               2723.95019531, 3133.08105469, 3413.01269531, 5663.23242188,
               5770.8984375 , 5577.09960938, 2228.68652344, 1604.22363281,
               1690.35644531, 4123.60839844, 5566.33300781, 5803.19824219,
               5749.36523438, 5846.26464844, 6772.19238281, 7073.65722656,
               7622.75390625, 7859.61914062, 8236.45019531, 8441.015625  ,
               8699.4140625 , 8807.08007812, 8742.48046875, 8667.11425781,
               8710.18066406, 8947.04589844, 9140.84472656, 9130.078125  ,
               8936.27929688, 8925.51269531, 8947.04589844, 8925.51269531,
               9097.77832031, 9205.44433594, 9194.67773438, 9140.84472656,
               9162.37792969, 9043.9453125 , 9162.37792969, 9108.54492188,
               9183.91113281, 9280.81054688, 9270.04394531, 9108.54492188,
               9076.24511719, 9356.17675781, 9226.97753906, 9216.2109375 ,
               9248.51074219, 9140.84472656, 9237.74414062, 9334.64355469,
               9259.27734375, 9226.97753906, 9216.2109375 , 9108.54492188,
               9183.91113281, 9216.2109375 , 9248.51074219, 9259.27734375,
               9183.91113281])],
       [0.017070271599460656, 171.91660927761163, 26.854424936811768,
        0.11188811188811189,
        array([4715.77148438, 4629.63867188, 4898.80371094, 5275.63476562,
               4941.87011719, 4532.73925781, 4618.87207031, 4995.703125  ,
               4705.00488281, 4500.43945312, 4188.20800781, 4371.24023438,
               4457.37304688, 4188.20800781, 4909.5703125 , 4877.27050781,
               6761.42578125, 7708.88671875, 7719.65332031, 7956.51855469,
               8484.08203125, 9033.17871094, 9043.9453125 , 9000.87890625,
               9011.64550781, 9011.64550781, 9000.87890625, 9108.54492188,
               8817.84667969, 6686.05957031, 1808.7890625 , 1830.32226562,
               1851.85546875, 1636.5234375 , 1022.82714844, 1281.22558594,
               1927.22167969, 1948.75488281, 1302.75878906, 1399.65820312,
               1873.38867188, 1959.52148438, 7245.92285156, 9011.64550781,
               9420.77636719, 9549.97558594, 9453.07617188, 9431.54296875,
               9410.00976562, 9248.51074219, 9151.61132812, 9194.67773438,
               8968.57910156, 8634.81445312, 8268.75      , 7439.72167969,
               5501.73339844, 5232.56835938, 5103.36914062, 7052.12402344,
               7299.75585938, 7127.49023438, 7192.08984375, 5673.99902344,
               5523.26660156, 5986.23046875, 6729.12597656, 6309.22851562,
               5135.66894531, 5081.8359375 , 5329.46777344, 5404.83398438])]],
      dtype=object)
```
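Note that the inner arrays in this X have different lengths from file to file. The likely cause of the error is that scikit-learn needs a rectangular 2-D float matrix, and a row made of four scalars plus one array of arbitrary length cannot be converted into one. The sketch below reproduces this with a hypothetical two-file `X` (toy values, same layout as above); the `astype` call is only an approximation of the conversion scikit-learn attempts when it validates its input.

```python
import numpy as np

# Hypothetical two-file X with the question's layout:
# four scalar features, then one variable-length array per file (toy values).
X = np.empty((2, 5), dtype=object)
X[0, :4] = [0.017, 227.7, 32.4, 0.387]
X[0, 4] = np.array([1851.9, 2433.3, 3057.7])
X[1, :4] = [0.039, 187.7, 62.6, 0.291]
X[1, 4] = np.array([4963.4, 3230.0])

# The lists differ in length, so there is no single rectangular shape for X
lengths = [len(v) for v in X[:, 4]]
print("list lengths:", lengths)

error_message = None
try:
    # Roughly the float conversion scikit-learn attempts when validating X
    X.astype(np.float64)
except ValueError as exc:
    error_message = str(exc)
    print("ValueError:", error_message)
```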

Answer:

There are two ways to feed a feature that contains a list into the model:

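  1. Treat the list as additional features (one column per list element).
  2. Map all elements of the list to a single number, using whatever function you think fits (minimum, median, mean, maximum, sum, etc.).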

Try the first approach:

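```python
import pandas as pd
from sklearn import svm
from sklearn.model_selection import train_test_split

# Convert `X` to a DataFrame
X = pd.DataFrame(X)
# Rename the columns
X.columns = ['feature_' + str(i + 1) for i in range(X.shape[1])]
# Convert the list-valued feature to long format (one row per list element)
x = X['feature_5'].explode().to_frame()
# Number each element by its position within its original row, so we can pivot
x['observation_id'] = x.groupby(level=0).cumcount()
# Pivot back to wide format, pad shorter lists with 0 and rename all columns
x = x.pivot(columns='observation_id', values='feature_5').astype(float).fillna(0)
x = x.add_prefix('list_element_')
# Drop `feature_5` from X
X = X.drop(columns='feature_5')
# Join X and x side by side
X = pd.concat([X, x], axis=1)
# Continue as before
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3)
model = svm.SVC()
model.fit(X_train, y_train)
```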
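If you would rather stay in NumPy, the same idea (zero-pad every list to the length of the longest one and append its elements as extra columns) can be sketched without pandas. This is an alternative to the code above, not the answerer's code; it assumes, as in the question, that the list sits in column 4 and the first four columns are scalars. The helper `pad_and_stack`, its `max_len` parameter and the toy values are hypothetical.

```python
import numpy as np

def pad_and_stack(X, list_col=4, max_len=None, fill_value=0.0):
    """Hypothetical helper: zero-pad (or truncate) the list column to a fixed
    width and append its elements to the scalar columns as a float matrix."""
    scalars = np.delete(X, list_col, axis=1).astype(np.float64)
    lists = [np.asarray(v, dtype=np.float64) for v in X[:, list_col]]
    if max_len is None:
        max_len = max(len(v) for v in lists)  # width of the longest list
    padded = np.full((len(lists), max_len), fill_value, dtype=np.float64)
    for i, v in enumerate(lists):
        v = v[:max_len]  # truncate lists longer than max_len
        padded[i, :len(v)] = v
    return np.hstack([scalars, padded])

# Toy two-file X in the question's layout
X = np.empty((2, 5), dtype=object)
X[0, :4] = [0.017, 227.7, 32.4, 0.387]
X[0, 4] = np.array([1851.9, 2433.3, 3057.7])
X[1, :4] = [0.039, 187.7, 62.6, 0.291]
X[1, 4] = np.array([4963.4, 3230.0])

F = pad_and_stack(X)                   # 4 scalar columns + 3 list columns
F_fixed = pad_and_stack(X, max_len=2)  # force a fixed width, e.g. for new data
print(F.shape, F_fixed.shape)
```

Padding treats position k of every list as the same feature, which only makes sense if the lists are aligned in that way. Also, with either version, data you predict on later must end up with the same number of columns as the training data, hence the optional fixed `max_len`.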

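There is no single right answer for the second approach: only you can decide how to do it, because only you know what these lists mean. However, if you want to take the mean of each list (for example) and use it as the feature, starting from the original NumPy array `X`: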

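```python
import numpy as np

# Take the mean of each list
means = [np.mean(arr) for arr in X[:, 4]]
# Replace the lists with `means`, then cast the whole matrix to float
X[:, 4] = means
X = X.astype(float)
```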
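Nothing limits you to one statistic: using several of the functions listed above at once keeps more information about each list while still giving every file the same number of columns. A minimal sketch, assuming the list is in column 4; the helper `summarize_lists`, the particular statistics and the toy values are illustrative choices, not part of the answer.

```python
import numpy as np

def summarize_lists(X, list_col=4):
    """Hypothetical helper: replace the variable-length column with
    fixed-size summary statistics (min, median, mean, max, sum)."""
    scalars = np.delete(X, list_col, axis=1).astype(np.float64)
    stats = np.array(
        [[np.min(v), np.median(v), np.mean(v), np.max(v), np.sum(v)]
         for v in X[:, list_col]],
        dtype=np.float64,
    )
    return np.hstack([scalars, stats])

# Toy two-file X in the question's layout
X = np.empty((2, 5), dtype=object)
X[0, :4] = [0.017, 227.7, 32.4, 0.387]
X[0, 4] = np.array([1.0, 2.0, 3.0])
X[1, :4] = [0.039, 187.7, 62.6, 0.291]
X[1, 4] = np.array([4.0, 6.0])

F = summarize_lists(X)
print(F.shape)  # 4 scalar columns + 5 summary columns
```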

Then continue with the split and fit as before.
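Putting the mean-based approach together end to end might look like the sketch below. The data is synthetic (random values and labels standing in for the real `X` and `y`), so the predictions carry no meaning. The `StandardScaler` step is an addition that is not in the answer: the scalar features in the question range from about 0.006 to a few hundred and the list values reach several thousand, and SVMs with the default RBF kernel are sensitive to feature scale.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

rng = np.random.default_rng(0)

# Synthetic stand-in for X: 40 "files", 4 scalars + 1 list of random length
n_files = 40
X = np.empty((n_files, 5), dtype=object)
y = np.arange(n_files) % 2  # hypothetical binary labels
for i in range(n_files):
    X[i, :4] = rng.normal(size=4)
    X[i, 4] = rng.uniform(0, 9000, size=rng.integers(30, 80))

# Second approach: replace each list with its mean, giving a float matrix
means = np.array([np.mean(v) for v in X[:, 4]])
features = np.column_stack([X[:, :4].astype(np.float64), means])

X_train, X_val, y_train, y_val = train_test_split(
    features, y, test_size=0.3, random_state=0
)
# Scale each column before the SVM (an added step, see above)
model = make_pipeline(StandardScaler(), SVC())
model.fit(X_train, y_train)
yv_p = model.predict(X_val)
```

Wrapping the scaler and the classifier in one pipeline means the scaling learned on the training split is reused automatically when predicting on new data.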
