降低错误预测类别的权重

我目前正在开发一个简单的预测系统,用户需要回答一系列的是非题,根据他们的回答,一个预训练的模型(MLPClassifier)会预测一个类别,并询问用户预测是否正确。我不确定这是否可行,但我希望能够以某种在线学习的方式调整预训练模型的权重,使得网络(在那个会话中)不会再次预测同样的类别。目前,我只是将错误的回答添加到一个字典中,如果网络预测的类别已经在黑名单类别集中,则忽略该预测,但我觉得一定有比这更好的方法!我的分类器代码如下:

mlp = MLPClassifier(hidden_layer_sizes=(128,), max_iter=500, alpha=1e-4,                    solver='sgd', verbose=10, tol=1e-4, random_state=1,                    learning_rate_init=.1, )x_train, x_test, y_train, y_test = train_test_split(df.values[:, 0:8], df.label_idx, test_size=0.33,                                                    random_state=42)

预测的代码如下:

def receive_input():responses = []bad_guesses = []print("Answer questions (Yes/No) or enter END to make prediction")count = 0while count < len(questions):    print(questions[count])    response = input().lower().strip()    if response == 'end':        break    elif response == 'yes':        responses.append(1)    elif response == 'no':        responses.append(0)    else:        print('Invalid Input')        continue    count += 1    padded_responses = np.pad(np.array(responses), (0, 8 - len(responses)), 'constant', constant_values=(0, -1))    prob_pred = mlp.predict_proba(padded_responses.reshape(1, -1)).flatten()    index = np.argmax(prob_pred)    best_score = prob_pred[index]    guess = labels[index]    if best_score > 0.8 and guess not in bad_guesses:        print('Early guess is: ' + labels[index] + ' is this right ? (Yes/No)')        correct = input()        if correct == 'Yes':            break        elif correct == 'No':            bad_guesses.append(labels[index])pred = mlp.predict(np.array(responses).reshape(1, -1))print('Prediction is: ' + labels[pred[0]])

回答:

mlp.coefs_ 提供了一个列表,其中第 ith 个元素代表第 i 层对应的权重矩阵

此外,mlp.intercepts_ 提供了一个列表,其中第 ith 个元素代表第 i + 1 层对应的偏置向量

所以你可以尝试看看这些属性是否可以更改。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注