SVM hangs during model fitting

I ran into a problem with SVM: execution hangs at model.fit(X_test, y_test), the step that fits the SVM model. How can I fix this? Here is my code:

# Make Predictions with Naive Bayes On The Iris Dataset
import collections
from csv import reader
from math import sqrt, exp, pi
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns; sns.set()
from sklearn.cross_validation import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.externals.six import StringIO
from sklearn.feature_selection import SelectKBest, chi2
from sklearn import datasets, metrics
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.naive_bayes import GaussianNB
from sklearn import svm
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier, export_graphviz


# Function to split the dataset
def splitdataset(balance_data, column_count):
    # Separating the target variable
    X = balance_data.values[:, 1:column_count]
    Y = balance_data.values[:, 0]

    # Splitting the dataset into train and test
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=0.3, random_state=100)

    return X, Y, X_train, X_test, y_train, y_test


def importdata():
    balance_data = pd.read_csv('dataExtended.txt', sep=',')
    row_count, column_count = balance_data.shape

    # Printing the dataset shape
    print("Dataset Length: ", len(balance_data))
    print("Dataset Shape: ", balance_data.shape)
    print("Number of columns ", column_count)

    # Printing the dataset observations
    print("Dataset: ", balance_data.head())
    balance_data['gold'] = balance_data['gold'].astype('category').cat.codes
    balance_data['Program'] = balance_data['Program'].astype('category').cat.codes
    return balance_data, column_count


# Driver code
def main():
    print("hey")

    # Building Phase
    data, column_count = importdata()
    X, Y, X_train, X_test, y_train, y_test = splitdataset(data, column_count)

    # Create an SVM classifier
    model = svm.SVC(kernel='linear', probability=True)  # Linear Kernel
    print('before fitting')
    model.fit(X_test, y_test)
    print('fitting over')
    predicted = model.predict(X_test)
    print('prediction over')
    print(metrics.classification_report(y_test, predicted))
    print('classification over')
    print(metrics.confusion_matrix(y_test, predicted))

    probs = model.predict_proba(X_test)
    probs_list = list(probs)
    y_pred = [None] * len(y_test)
    y_pred_list = list(y_pred)
    y_test_list = list(y_test)
    i = 0
    threshold = 0.7
    while i < len(probs_list):
        # print('probs ', probs_list[i][0])
        if (probs_list[i][0] >= threshold) & (probs_list[i][1] < threshold):
            y_pred_list[i] = 0
            i = i + 1
        elif (probs_list[i][0] < threshold) & (probs_list[i][1] >= threshold):
            y_pred_list[i] = 1
            i = i + 1
        else:
            # print(y_pred[i])
            # print('i==> ', i, ' probs length ', len(probs_list), ' ', len(y_pred_list), ' ', len(y_test_list))
            y_pred_list.pop(i)
            y_test_list.pop(i)
            probs_list.pop(i)

    # print(y_pred_list)
    print('confusion matrix\n', confusion_matrix(y_test_list, y_pred_list))
    print('classification report\n', classification_report(y_test_list, y_pred_list))
    print('accuracy score', accuracy_score(y_test_list, y_pred_list))
    print('Mean Absolute Error:', metrics.mean_absolute_error(y_test_list, y_pred_list))
    print('Mean Squared Error:', metrics.mean_squared_error(y_test_list, y_pred_list))
    print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test_list, y_pred_list)))


if __name__ == "__main__":
    main()

Answer:

This is most likely caused by setting the probability parameter to True when initializing the model. As the documentation says:

probability: bool, default=False

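Whether to enable probability estimates. This must be enabled prior to calling fit, will slow down that method as it internally uses 5-fold cross-validation, and predict_proba may be inconsistent with predict.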

This has been discussed on StackOverflow before, here and here.
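To check whether probability=True is what makes fit slow, you can time both settings. The following is a minimal sketch, not the asker's setup: it assumes a synthetic dataset from make_classification in place of dataExtended.txt, and the helper time_fit and the margin value are hypothetical names chosen for illustration. It also shows decision_function, which SVC provides even without probability=True, as one possible way to keep the question's "only classify confident samples" idea; a margin on the decision score is not equivalent to the 0.7 probability threshold.

```python
# Sketch: time a linear SVC with and without probability=True.
# Assumption: make_classification stands in for the asker's dataExtended.txt.
import time

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC


def time_fit(probability, X, y):
    """Hypothetical helper: fit a linear-kernel SVC, return (model, seconds)."""
    model = SVC(kernel="linear", probability=probability, random_state=0)
    start = time.perf_counter()
    model.fit(X, y)
    return model, time.perf_counter() - start


X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
# Fit on the training split only; evaluate on the held-out test split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=100)

plain, t_plain = time_fit(False, X_train, y_train)
with_proba, t_proba = time_fit(True, X_train, y_train)
print(f"probability=False: {t_plain:.3f}s")
print(f"probability=True:  {t_proba:.3f}s")

# With probability=True, predict_proba is available (its calibration is
# fitted via the internal cross-validation mentioned in the docs).
probs = with_proba.predict_proba(X_test)

# Without it, decision_function still returns a signed score per sample:
# for a binary SVC, a positive score means classes_[1].
scores = plain.decision_function(X_test)

# Hypothetical margin: leave samples with |score| < margin unclassified,
# loosely analogous to the 0.7 probability threshold in the question.
margin = 1.0
confident = np.abs(scores) >= margin
y_pred_confident = plain.classes_[(scores[confident] > 0).astype(int)]
y_test_confident = y_test[confident]
print(f"kept {confident.sum()} of {len(y_test)} test samples")
```

The exact timings depend on the data and machine; the point of the comparison is only that the probability=True fit does extra work (the internal 5-fold cross-validation) on top of the plain fit.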
