Commit d25a36f

Merge pull request #16 from microsoft/pr_for_running_roberta_with_ortmodule
Hack to make RoBERTa run with ORTModule.
2 parents: efc9019 + b25c43e
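For context: ORTModule, from the onnxruntime-training package, wraps a torch.nn.Module so that its forward and backward passes execute through ONNX Runtime, which requires the model graph to be ONNX-exportable. A minimal usage sketch, assuming onnxruntime-training is installed (the checkpoint name is illustrative, not part of this commit):

    from transformers import AutoModelForQuestionAnswering
    from onnxruntime.training.ortmodule import ORTModule

    # Wrap the PyTorch model; forward/backward now execute through ONNX Runtime.
    model = AutoModelForQuestionAnswering.from_pretrained("roberta-base")
    model = ORTModule(model)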

File tree

2 files changed (+3, -1 lines)

examples/pytorch/question-answering/run_qa.py (1 addition, 0 deletions)

@@ -283,6 +283,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ort=training_args.ort,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
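The new ort keyword rides along with the other from_pretrained arguments. Presumably it ends up on the model config, since upstream from_pretrained already forwards unrecognized keyword arguments to the config; the patched __init__ in modeling_roberta.py (below) then reads it as self.ort. A hedged sketch of the equivalent effect, using roberta-base as a stand-in checkpoint:

    from transformers import AutoConfig, AutoModelForQuestionAnswering

    config = AutoConfig.from_pretrained("roberta-base")
    config.ort = True  # PretrainedConfig accepts arbitrary extra attributes
    model = AutoModelForQuestionAnswering.from_pretrained("roberta-base", config=config)
    print(model.config.ort)  # True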

src/transformers/models/roberta/modeling_roberta.py (2 additions, 1 deletion)

@@ -1414,6 +1414,7 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.ort = config.ort

         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
@@ -1480,7 +1481,7 @@ def forward(
             if len(end_positions.size()) > 1:
                 end_positions = end_positions.squeeze(-1)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
+            ignored_index = start_logits.size(1) if not self.ort else 344
             start_positions.clamp_(0, ignored_index)
             end_positions.clamp_(0, ignored_index)

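The functional change is the hardcoded ignored_index. start_logits.size(1) is a data-dependent shape read, and when ORTModule exports the graph to ONNX, dynamic values feeding the clamp bounds are exactly the kind of thing a "hack" like this sidesteps; 344 is presumably pinned to the sequence length of the target run. A standalone illustration of what ignored_index does in the surrounding loss code (the tensors are made up):

    import torch
    from torch import nn

    ignored_index = 344                        # fixed bound used when self.ort is True
    start_positions = torch.tensor([17, 900])  # 900 lies outside the model inputs
    start_positions.clamp_(0, ignored_index)   # -> tensor([ 17, 344])

    # In forward(), positions clamped to ignored_index contribute no loss:
    loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)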