|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +from transformers import BertModel, BertTokenizer, BertConfig |
| 4 | +import torch.nn.functional as F |
| 5 | + |
| 6 | + |
| 7 | +# Sample Pool Model (for testing plugin serialization) |
| 8 | +class Pool(nn.Module): |
| 9 | + |
| 10 | + def __init__(self): |
| 11 | + super(Pool, self).__init__() |
| 12 | + |
| 13 | + def forward(self, x): |
| 14 | + return F.adaptive_avg_pool2d(x, (5, 5)) |
| 15 | + |
| 16 | + |
| 17 | +# Sample Nested Module (for module-level fallback testing) |
| 18 | +class ModuleFallbackSub(nn.Module): |
| 19 | + |
| 20 | + def __init__(self): |
| 21 | + super(ModuleFallbackSub, self).__init__() |
| 22 | + self.conv = nn.Conv2d(1, 3, 3) |
| 23 | + self.relu = nn.ReLU() |
| 24 | + |
| 25 | + def forward(self, x): |
| 26 | + return self.relu(self.conv(x)) |
| 27 | + |
| 28 | + |
| 29 | +class ModuleFallbackMain(nn.Module): |
| 30 | + |
| 31 | + def __init__(self): |
| 32 | + super(ModuleFallbackMain, self).__init__() |
| 33 | + self.layer1 = ModuleFallbackSub() |
| 34 | + self.conv = nn.Conv2d(3, 6, 3) |
| 35 | + self.relu = nn.ReLU() |
| 36 | + |
| 37 | + def forward(self, x): |
| 38 | + return self.relu(self.conv(self.layer1(x))) |
| 39 | + |
| 40 | + |
| 41 | +# Sample Looping Modules (for loop fallback testing) |
| 42 | +class LoopFallbackEval(nn.Module): |
| 43 | + |
| 44 | + def __init__(self): |
| 45 | + super(LoopFallbackEval, self).__init__() |
| 46 | + |
| 47 | + def forward(self, x): |
| 48 | + add_list = torch.empty(0).to(x.device) |
| 49 | + for i in range(x.shape[1]): |
| 50 | + add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0) |
| 51 | + return x + add_list |
| 52 | + |
| 53 | + |
| 54 | +class LoopFallbackNoEval(nn.Module): |
| 55 | + |
| 56 | + def __init__(self): |
| 57 | + super(LoopFallbackNoEval, self).__init__() |
| 58 | + |
| 59 | + def forward(self, x): |
| 60 | + for _ in range(x.shape[1]): |
| 61 | + x = x + torch.ones_like(x) |
| 62 | + return x |
| 63 | + |
| 64 | + |
| 65 | +# Sample Conditional Model (for testing partitioning and fallback in conditionals) |
| 66 | +class FallbackIf(torch.nn.Module): |
| 67 | + |
| 68 | + def __init__(self): |
| 69 | + super(FallbackIf, self).__init__() |
| 70 | + self.relu1 = torch.nn.ReLU() |
| 71 | + self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1) |
| 72 | + self.log_sig = torch.nn.LogSigmoid() |
| 73 | + self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1) |
| 74 | + self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1) |
| 75 | + |
| 76 | + def forward(self, x): |
| 77 | + x = self.relu1(x) |
| 78 | + x_first = x[0][0][0][0].item() |
| 79 | + if x_first > 0: |
| 80 | + x = self.conv1(x) |
| 81 | + x1 = self.log_sig(x) |
| 82 | + x2 = self.conv2(x) |
| 83 | + x = self.conv3(x1 + x2) |
| 84 | + else: |
| 85 | + x = self.log_sig(x) |
| 86 | + x = self.conv1(x) |
| 87 | + return x |
| 88 | + |
| 89 | + |
| 90 | +def BertModule(): |
| 91 | + model_name = "bert-base-uncased" |
| 92 | + enc = BertTokenizer.from_pretrained(model_name) |
| 93 | + text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" |
| 94 | + tokenized_text = enc.tokenize(text) |
| 95 | + masked_index = 8 |
| 96 | + tokenized_text[masked_index] = "[MASK]" |
| 97 | + indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) |
| 98 | + segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] |
| 99 | + tokens_tensor = torch.tensor([indexed_tokens]) |
| 100 | + segments_tensors = torch.tensor([segments_ids]) |
| 101 | + config = BertConfig( |
| 102 | + vocab_size_or_config_json_file=32000, |
| 103 | + hidden_size=768, |
| 104 | + num_hidden_layers=12, |
| 105 | + num_attention_heads=12, |
| 106 | + intermediate_size=3072, |
| 107 | + torchscript=True, |
| 108 | + ) |
| 109 | + model = BertModel(config) |
| 110 | + model.eval() |
| 111 | + model = BertModel.from_pretrained(model_name, torchscript=True) |
| 112 | + traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) |
| 113 | + return traced_model |
0 commit comments