tensorflow.js GPU崩溃

我的模型接收一张图片并计算某个值,输入层是一个裁剪层,用于从图片的顶部和底部移除一定数量的像素。模型运行得相当好,但是当我更改裁剪层的设置时,例如,将顶部移除的像素从75改为25,浏览器窗口(Chrome)会闪烁并输出以下错误:

enter image description here

注意:在上述错误之前,它会打印出以下消息 Couldn't parse line number in error,随后是看起来像是GLSL代码的内容。

如果完全移除裁剪层,也会出现相同的错误。

我使用的是tfjs v3.8.0版本,但也测试了v2.0.0版本,结果类似。这是我的模型:

const model = tf.sequential();// Cropping Layermodel.add(  tf.layers.cropping2D({    // 如果我将75改为低于50的任何值,它会在完成第一个epoch之前崩溃,    // 如果移除这一层,它几乎会在训练开始后立即崩溃    cropping: [      [75, 25],      [0, 0]    ],    // 图片高度,宽度,深度    inputShape: [160, 320, 3]  }));model.add(  tf.layers.conv2d({    filters: 16,    kernelSize: [3, 3],    strides: [2, 2],    activation: 'relu',  }));model.add(  tf.layers.maxPool2d({    poolSize: [2, 2]  }));model.add(  tf.layers.conv2d({    filters: 32,    kernelSize: [3, 3],    strides: [2, 2],    activation: 'relu'  }));model.add(  tf.layers.maxPool2d({    poolSize: [2, 2]  }));model.add( tf.layers.flatten());model.add( tf.layers.dense({ units: 1024, activation: 'relu' }));model.add( tf.layers.dropout({ rate: 0.25 }));model.add( tf.layers.dense({ units: 128, activation: 'relu' }));model.add( tf.layers.dense({ units: 1, activation: 'linear' }));model.compile({  optimizer: 'adam',  loss: 'meanSquaredError',  metrics: [    'accuracy',  ],});

我做错了什么明显的地方吗?


回答:

正如@vladimir-mandic建议的,问题最终确实是由于GPU内存不足引起的。但将WEBGL_DELETE_TEXTURE_THRESHOLD设置为零在我的情况下没有帮助。

我花了一段时间来验证这一点,因为它发生在批次之间,我无法通过tf.memory()batchEnd跟踪,因为在回调中内存已经释放,或者GPU会在到达该点之前崩溃。我最终采取了以下两个措施来克服这个问题:

  • 减小图像尺寸:裁剪层有助于避免达到“内存不足”的状态,因此移除它或减少裁剪的像素数量会导致应用程序崩溃。但由于它也分配张量,我决定在将图像输入模型之前,通过canvas操作调整图像大小。
  • 减小batchSize:我一直在使用默认的batchSize 32,直到我将其减小时,我才注意到崩溃问题消失了,这促使我调查model.fitDataset的内部工作原理,并发现了批次之间过度消耗内存的问题。

正如@vladimir-mandic推荐的,将WEBGL_DELETE_TEXTURE_THRESHOLD设置为0也应该有助于缓解这个问题,但在我的情况下我没有注意到明显的效果,所以我最终没有使用它。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注