使用torch-mlir加载MLIR文件作为模型

我已经按照GitHub上的示例,使用torch-mlir将一个torch模型转换为MLIR:

def save_module():
    resnet18 = models.resnet18(pretrained=True)
    resnet18.eval()

    module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
    open("resnet18torch.mlir", "w").write(str(module))

之后,我修改了MLIR,现在我想做相反的操作,在我的Python代码中将这个MLIR加载为一个模块,以便继续编译它,像这样:

src = open("resnet18torch.mlir", "r").read()
#transform src to module
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
predictions(resnet18.forward, jit_module.forward, img, labels)

我只需要在第二段代码中添加一行来加载MLIR作为模块,但我在网上找不到任何相关信息。有人知道如何做吗?


回答:

我最终在LLVM discord的帮助下,弄清楚了如何加载torch mlir并如何重用它。以下是使用torch_mlir示例(torchscript_resnet18.py)的用法:

import torch_mlir
def load_module():
    #load the torch mlir
    src = open("resnet18torch.mlir", "r").read()
    with torch_mlir.ir.Context() as ctx:
        torch_mlir.dialects.torch.register_dialect(ctx)
        with torch_mlir.ir.Location.unknown() as loc:
            module = torch_mlir.ir.Module.parse(src)
    #translate the torch-mlir to the linalg-on-tensors dialect
    torch_mlir.run_pipeline_with_repro_report(module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
    #print(module)
    #compile the module
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(module)
    jit_module = backend.load(compiled)
    return jit_module
#save_module()
jit_module = load_module()
predictions(jit_module.forward, img, labels)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注