在矩阵乘法中出现错误:内层形状(1)和(2)的张量,其形状分别为684,1和2,1,且transposeA=false和transposeB=false时必须匹配

我是人工智能和TensorFlow.js的初学者,目前正在跟随Stephen Grider的机器学习课程。运行以下代码后,本应得到输出结果,但却遇到了错误。请帮助我:

代码:linear-regression.js:

const tf = require('@tensorflow/tfjs');class LinearRegression {    constructor(features, labels, options) {        this.features = tf.tensor(features);        this.labels = tf.tensor(labels);        this.features = tf.ones([this.features.shape[0], 1]).concat(this.features) //生成马力的一列        this.options = Object.assign(            { learningRate: 0.1, iterations: 1000 },             options        ); //默认学习率为0.1,如果提供了学习率,则使用提供的值...迭代次数为梯度下降运行的次数        this.weights = tf.zeros([2, 1]); //初始权重m和b的张量为零    }    gradientDescent() {        const currentGuesses = this.features.matMul(this.weights); //matMul是矩阵乘法,即特征乘以权重        const differences = currentGuesses.sub(this.labels); //(特征 * 权重) - 标签        const slopes = this.features            .transpose()            .matMul(differences)            .div(features.shape[0]); //相对于m和b的MSE斜率。特征 * ((特征 * 权重) - 标签) / 特征总数                this.weights = this.weights.sub(slopes.mul(this.options.learningRate));    }    train() {        for (let i=0; i < this.options.iterations; i++) {            this.gradientDescent();        }        /*test(testFeatures, testLabels) {            testFeatures = tf.tensor(testFeatures);            testLabels = tf.tensor(testLabels);        } */    }}module.exports = LinearRegression;

index.js:

require('@tensorflow/tfjs-node');const tf = require('@tensorflow/tfjs');const loadCSV = require('./load-csv');const LinearRegression = require('./linear-regression');let { features, labels, testFeatures, testLabels } =loadCSV('./cars.csv', {    shuffle: true,    splitTest: 50,    dataColumns: ['horsepower'],    labelColumns: ['mpg']});const regression = new LinearRegression(features, labels, {    learningRate: 0.002,    iterations: 100});regression.train();console.log(    '更新后的M值为:',     regression.weights.get(1, 0),     '更新后的B值为:',     regression.weights.get(0, 0)    );

错误:

D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\ops\operation.js:32            throw ex;            ^Error: 在矩阵乘法中出现错误:内层形状(1)和(2)的张量,其形状分别为684,1和2,1,且transposeA=false和transposeB=false时必须匹配。    at Object.assert (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\util.js:36:15)    at matMul_ (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\ops\matmul.js:25:10)    at Object.matMul (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\ops\operation.js:23:29)    at Tensor.matMul (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\tensor.js:315:26)    at LinearRegression.gradientDescent (D:\Application Development\MLKits-master\MLKits-master\regressions\linear-regression.js:19:46)    at LinearRegression.train (D:\Application Development\MLKits-master\MLKits-master\regressions\linear-regression.js:34:18)    at Object.<anonymous> (D:\Application Development\MLKits-master\MLKits-master\regressions\index.js:18:12)    at Module._compile (internal/modules/cjs/loader.js:1063:30)    at Object.Module._extensions..js (internal/modules/cjs/loader.js:1092:10)    at Module.load (internal/modules/cjs/loader.js:928:32)

回答:

错误是由以下代码抛出的:

this.features.matMul(this.weights)

this.features的形状为[684, 1],而this.weights的形状为[2, 1]。要使矩阵A(形状为[a, b])能够与矩阵B(形状为[c, d])相乘,bc必须匹配,但这里并不符合这一条件。

要解决这个问题,this.weights应该进行转置:

this.features.matMul(this.weights, false, true)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注