Skip to content

Commit 8ff016a

Browse files
author
Vremold
committed
[MHLO] bert-tiny and resnet18 example from torchscript to mhlo
Co-authored-by: Bairen Yi <[email protected]> Co-authored-by: Jiawei Wu <[email protected]> Co-authored-by: Tianyou Guo <[email protected]> Co-authored-by: Xu Yan <[email protected]> Co-authored-by: Ziheng Jiang <[email protected]>
1 parent 8cad02f commit 8ff016a

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
import torchvision.models as models
3+
import torch_mlir
4+
5+
model = models.resnet18(pretrained=True)
6+
model.eval()
7+
data = torch.randn(2,3,200,200)
8+
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
9+
10+
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
11+
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
12+
outf.write(str(module))
13+
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
import torch_mlir
3+
4+
from transformers import BertForMaskedLM
5+
6+
# Wrap the bert model to avoid multiple returns problem
7+
class BertTinyWrapper(torch.nn.Module):
8+
def __init__(self) -> None:
9+
super().__init__()
10+
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
11+
12+
def forward(self, data):
13+
return self.bert(data)[0]
14+
15+
model = BertTinyWrapper()
16+
model.eval()
17+
data = torch.randint(30522, (2, 128))
18+
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
19+
20+
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
21+
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
22+
outf.write(str(module))
23+
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")

0 commit comments

Comments
 (0)