Matplotlib的scatter函数现在需要三个参数?

我编写了一个简短的脚本来显示一些图表:

if __name__=='__main__':
    p1 = gen_spiral(label=0, dt=0,     n_samples=100, )
    p2 = gen_spiral(label=1, dt=np.pi, n_samples=100, )
    print ("Array:   {}\nType:   {}\nShape:   {}\nLength:   {}\nData:   {}\n".format("p1",   type(p1),   str(np.shape(p1)),   len(p1),    str(p1)))
    print ("Array:   {}\nType:   {}\nShape:   {}\nLength:   {}\nData:   {}\n".format("p2",   type(p2),   str(np.shape(p2)),   len(p2),    str(p2)))
    a = np.arange(1,20,1)
    b = np.arange(1,20,1)
    c = np.arange(0.0, 2.0, 0.01)
    d = np.sin(2*np.pi*c)
    fig1 = plt.figure()
    ax1 = fig1.add_subplot(121)
    ax1.scatter(a,b)
    ax2 = fig1.add_subplot(122)
    ax2.scatter(c,d)

脚本运行得很好。然而,当我更改函数以显示我的数据时:

fig1 = plt.figure()
ax1 = fig1.add_subplot(121)
ax1.scatter(p1)

它给出了一个不应该存在的错误:

Traceback (most recent call last):
  File "Theano--PlotSet--ME01.py", line 53, in <module>
    ax1.scatter(p1)
TypeError: scatter() takes at least 3 arguments (2 given)

这不是真的:Scatter需要3个参数,而p1有两个部分:

    jason@jason-HP-43299:~/Programs/MachineLearning/SectionOne$ python TheanoPS00.py
    Array:   p1
    Type:   <type 'numpy.ndarray'>
    Shape:   (100, 2)
    Length:   100
    Data:   [[-0.0617  0.0534]
     [ 0.0299  0.0913]
     [ 0.0094  0.157 ]
     [ 0.1057  0.1535]
     [ 0.1412  0.2741]
     [ 0.0851  0.1426]

这里到底发生了什么?


回答:

你在这里调用的scatter函数是axes类的类方法。它的签名是
scatter(self, x, y, *args, **kwargs),有3个位置参数,第一个类参数self是你通过调用axes实例的方法隐式提供的。

了解这一点后,可以将错误消息翻译为“scatter() 至少需要2个用户指定的参数(只给出了1个)”。这是有道理的,因为scatter(p1)使用了1个参数。

你需要做的是将你的数组分成xy数组,

ax1.scatter(p1[:,0], p1[:,1])

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注