在逻辑回归中理解 np.where()

我目前正在学习由Andrew Ng在Coursera上教授的深度学习专业课程。在第一个作业中,我需要定义一个预测函数,并想知道我的替代解决方案是否与实际解决方案一样有效。

请告诉我我对np.where()函数的理解是否正确,我已经在代码中的“ALTERNATIVE SOLUTION COMMENTS”部分进行了评论。另外,如果能检查我在“ACTUAL SOLUTION COMMENTS”部分的理解,我将不胜感激。

当我尝试增加X中的示例/输入数量时,当前数量(m = 3)增加到4、5等等,使用np.where()的替代解决方案也能正常工作。

请告诉我你的想法,以及这两个解决方案是否同样好!谢谢。

def predict(w, b, X):    '''    使用学习到的逻辑回归参数(w, b)预测标签是0还是1    参数:    w -- 权重,形状为(num_px * num_px * 3, 1)的numpy数组    b -- 偏置,一个标量    X -- 形状为(num_px * num_px * 3, 示例数量)的数据    返回:    Y_prediction -- 包含X中所有示例预测(0/1)的numpy数组(向量)    '''    m = X.shape[1]    Y_prediction = np.zeros((1,m))    # 初始化Y_prediction为零数组     w = w.reshape(X.shape[0], 1)    # 计算向量"A",预测图片中存在猫的概率    ### START CODE HERE ### (≈ 1 line of code)    A = sigmoid(np.dot(w.T, X) + b)   # 注意:A的形状总是(1,m)行向量    ### END CODE HERE ###    for i in range(A.shape[1]):       # for i in range(A中的示例数量 = 我们的集合中的示例数量)        # 将概率A[0,i]转换为实际预测p[0,i]        ### START CODE HERE ### (≈ 4 lines of code)        Y_prediction[0, i] = 1 if A[0, i] > 0.5 else 0        '''         ACTUAL SOLUTION COMMENTS:         这表示:        如果A中的第i个值大于0.5,则将Y_prediction的第i个值更新为1。         否则,将Y_prediction的第i个值更新为0。         '''        '''         ALTERNATIVE SOLUTION COMMENTS:        为了简化代码,你可以删除for循环和顶部的Y_prediction变量,         然后使用以下一行代码:         return np.where(A > 0.5, np.ones((1,m)), np.zeros((1,m)))         这表示:         给定条件> 0.5,如果为真,则返回np.ones((1,m)),         如果为假,则返回np.zeros((1,m))。         另一种理解方式如下:        告诉我数组A中哪些条目满足条件A > 0.5,        在这些位置,给我np.ones((1,m)),否则,给我         np.zeros((1,m))        '''        ### END CODE HERE ###    assert(Y_prediction.shape == (1, m))    return Y_predictionw = np.array([[0.1124579],[0.23106775]])b = -0.3X = np.array([[1.,-1.1,-3.2],[1.2,2.,0.1]])print(sigmoid(np.dot(w.T, X) + b))print ("predictions = " + str(predict(w, b, X)))   # 输出如预期给出1,1,0

回答:

你的替代方法看起来不错。作为补充,我要说你甚至不需要np.onesnp.zeros,你可以直接指定整数01。使用np.where时,只要Xy(根据条件替换的值)和相同的条件是可广播的,就应该能正常工作。这里有一个简单的例子:

y_pred = np.random.rand(1,6).round(2)# array([[0.53, 0.54, 0.68, 0.34, 0.53, 0.46]])np.where(y_pred> 0.5, np.ones((1,6)), np.zeros((1,6)))# array([[1., 1., 1., 0., 1., 0.]])

使用整数的情况:

np.where(y_pred> 0.5,1,0)# array([[1, 1, 1, 0, 1, 0]])

关于你对函数工作原理的评论,确实如你所描述的那样工作。或许,与其说To condense this code,我认为使用numpy使其更有效,并且在这种情况下也更易于理解。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注