为了可视化我的线性回归模型的梯度下降,我试图为以下mse
函数绘制等高线图:
import jax.numpy as jnpimport numpy as npdef make_mse(x, t): def mse(w,b): return np.sum(jnp.power(x.dot(w) + b - t, 2))/2 return mse
其中图表的x
和y
轴分别对应w
和b
参数。
x
和t
对于图表来说并不重要,因为x
的值每次只是被w
的一个单一值相乘。
我尝试做如下操作:
x = np.linspace(-1.0,1.0,500)t = 5*x + 1xcoord = np.linspace(-10.0,10.0,50)ycoord = np.linspace(-10.0,10.0,50)w1,w2 = np.meshgrid(xcoord,ycoord)Z = make_mse(x, t)(w1,w2)
然而,我得到了明显的点积错误:
/usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in dot(lhs, rhs, precision, preferred_element_type) 634 else: 635 raise TypeError("Incompatible shapes for dot: got {} and {}.".format(--> 636 lhs.shape, rhs.shape)) 637 638 TypeError: Incompatible shapes for dot: got (1000, 1) and (50, 50).
有什么Pythonic且高效的方法可以绘制这个函数的等高线图吗?
回答:
你不需要np.sum()
,因为你想要的是每个网格点的MSE值,而不是它们的总和。此外,x
的维度必须与网格匹配。以下方法有效:
import numpy as npdef make_mse(x, t): def mse(w,b): return np.power(x.dot(w) + b - t, 2) return mse x = np.linspace(-1.0,1.0,500)t = 5*x + 1xcoord = np.linspace(-10.0,10.0,500)ycoord = np.linspace(-10.0,10.0,500)w1,w2 = np.meshgrid(xcoord,ycoord)Z = make_mse(x, t)(w1,w2)plt.contourf(w1,w2,Z)
以下是生成的等高线图输出