Closed
Description
I am trying to export a TorchScript module for `AlbertForQuestionAnswering` using `torch.jit.script`:
```python
self.model = AlbertForQuestionAnswering.from_pretrained(self.model_dir)
script_model = torch.jit.script(self.model)
script_model.save("script_model.pt")
```
Scripting fails with the following exception:
```
Python builtin <built-in function next> is currently not supported in Torchscript:
at /usr/local/lib/python3.6/dist-packages/transformers/modeling_albert.py:523:67
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
                                                                   ~~~~ <--- HERE
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
        else:
'__torch__.transformers.modeling_albert.___torch_mangle_15.AlbertModel.forward' is being compiled since it was called from '__torch__.transformers.modeling_albert.___torch_mangle_14.AlbertForQuestionAnswering.forward'
at /usr/local/lib/python3.6/dist-packages/transformers/modeling_albert.py:767:8
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                start_positions=None, end_positions=None):
        outputs = self.albert(
        ~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds
        )
        sequence_output = outputs[0]
```