sklearn SVM 自定义核函数引发 ValueError: X.shape[0] 应等于 X.shape[1]

我在尝试实现一个自定义核函数,具体来说是指数卡方核函数,并将其作为参数传递给 sklearn 的 svm 函数,但在运行时引发了以下错误:ValueError: X.shape[0] 应等于 X.shape[1]

我了解到 numpy 的函数通过广播操作来加速计算,但无法解决这个错误。

代码如下:

import numpy as npfrom sklearn import svm, datasets# 导入鸢尾花数据集 (http://en.wikipedia.org/wiki/Iris_flower_data_set)iris = datasets.load_iris()train_features = iris.data[:, :2]  # 这里我们只使用前两个特征。train_labels = iris.targetdef my_kernel(x, y):    gamma = 1    return np.exp(-gamma * np.divide((x - y) ** 2, x + y))classifier = svm.SVC(kernel=my_kernel)classifier = classifier.fit(train_features, train_labels)print "Train Accuracy : " + str(classifier.score(train_features, train_labels))

有什么帮助吗?


回答:

我认为卡方核函数已经为您实现了(在 from sklearn.metrics.pairwise import chi2_kernel 中)。

如下所示

from functools import partialfrom sklearn import svm, datasetsfrom sklearn.metrics.pairwise import chi2_kernel# 导入鸢尾花数据集 (http://en.wikipedia.org/wiki/Iris_flower_data_set)iris = datasets.load_iris()train_features = iris.data[:, :2]  # 这里我们只使用前两个特征。train_labels = iris.targetmy_chi2_kernel = partial(chi2_kernel, gamma=1)classifier = svm.SVC(kernel=my_chi2_kernel)classifier = classifier.fit(train_features, train_labels)print("Train Accuracy : " + str(classifier.score(train_features, train_labels)))

====================

编辑:

所以问题实际上是关于如何实现卡方核函数。我尝试这样做:

def my_chi2_kernel(X):    gamma = 1    nom = np.power(X[:, np.newaxis] - X, 2)    denom = X[:, np.newaxis] + X    # 注意:我们需要修复一些条目,因为除以0是一个问题。    #       所以我们获取所有将成为0的分母的索引,并修复它们。    zero_denom_idx = denom == 0    nom[zero_denom_idx] = 0    denom[zero_denom_idx] = 1    return np.exp(-gamma * np.sum(nom / denom, axis=len(X.shape)))

本质上,原始尝试中的 x - yx + y 是错误的,因为它们不是成对的减法或加法。

有趣的是,自定义版本似乎比 sklearn 的 cython 版本更快(至少对于小数据集?)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注