
Commit 2e81b9d

Bart: update example for #3140 compatibility (#3233)

* Update bart example docs

1 parent 72768b6

3 files changed: 22 insertions(+), 4 deletions(−)


examples/summarization/bart/README.md
Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,4 @@
-### Get the CNN/Daily Mail Data
+### Get the CNN Data
 To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
 
 ```bash
@@ -32,6 +32,7 @@ unzip stanford-corenlp-full-2018-10-05.zip
 cd stanford-corenlp-full-2018-10-05
 export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
 ```
+Then run `ptb_tokenize` on `test.target` and your generated hypotheses.
 ### Rouge Setup
 Install `files2rouge` following the instructions at [here](https://github.com/pltrdy/files2rouge).
 I also needed to run `sudo apt-get install libxml-parser-perl`
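
For context, here is roughly how the new tokenization step slots into the ROUGE workflow. This is a sketch, not a command defined in this commit: it assumes the `CLASSPATH` export above, references in `test.target`, hypotheses written to a hypothetical `test.hypo`, and `files2rouge` installed; the `PTBTokenizer` invocation follows the upstream fairseq BART recipe, and `ptb_tokenize` may be any equivalent wrapper around it.

```bash
# Sketch: PTB-tokenize references and hypotheses, then score with files2rouge.
# Assumes CLASSPATH points at the Stanford CoreNLP jars exported above.
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.target.tokenized
files2rouge test.hypo.tokenized test.target.tokenized
```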

examples/summarization/bart/evaluate_cnn.py
Lines changed: 4 additions & 2 deletions

@@ -27,9 +27,11 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
             attention_mask=dct["attention_mask"].to(device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=140,
-            min_length=55,
+            max_length=142,  # +2 from original because we start at step=1 and stop before max_length
+            min_length=56,  # +1 from original because we start at step=1
             no_repeat_ngram_size=3,
+            early_stopping=True,
+            do_sample=False,
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
         for hypothesis in dec:
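
To see the updated generation settings in one self-contained call, here is a minimal sketch that mirrors the diff above, assuming the `bart-large-cnn` checkpoint name used elsewhere in this commit (the batching and file I/O of `evaluate_cnn.py` are omitted):

```python
# Minimal sketch of the updated generation call from evaluate_cnn.py.
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")

ARTICLE = "My friends are cool but they eat too many carbs."  # any source text
dct = tokenizer.batch_encode_plus([ARTICLE], max_length=1024, return_tensors="pt")
summaries = model.generate(
    input_ids=dct["input_ids"],
    attention_mask=dct["attention_mask"],
    num_beams=4,
    length_penalty=2.0,
    max_length=142,  # +2: generation starts at step=1 and stops before max_length
    min_length=56,  # +1: generation starts at step=1
    no_repeat_ngram_size=3,
    early_stopping=True,  # finish beams as soon as num_beams hypotheses are done
    do_sample=False,  # pure beam search, no sampling
)
print(tokenizer.decode(summaries[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```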

src/transformers/modeling_bart.py
Lines changed: 16 additions & 1 deletion

@@ -45,6 +45,20 @@
     Initializing with a config file does not load the weights associated with the model, only the configuration.
     Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
 
+"""
+BART_GENERATION_EXAMPLE = r"""
+    Examples::
+
+        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
+        # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
+        model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
+        tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
+        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+        inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
+        # Generate Summary
+        summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
+        print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
+
 """
 
 BART_INPUTS_DOCSTRING = r"""
@@ -855,7 +869,8 @@ def get_output_embeddings(self):
 
 
 @add_start_docstrings(
-    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING,
+    "The BART Model with a language modeling head. Can be used for summarization.",
+    BART_START_DOCSTRING + BART_GENERATION_EXAMPLE,
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
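
The `BART_START_DOCSTRING + BART_GENERATION_EXAMPLE` argument works because `add_start_docstrings` just prepends its string arguments to the decorated object's `__doc__`, so concatenating the two fragments beforehand is equivalent to passing them separately. A rough sketch of that behavior (not the library's exact source):

```python
# Rough sketch of transformers.add_start_docstrings: prepend the given
# docstring fragments to the decorated class's or function's own __doc__.
def add_start_docstrings(*docstr):
    def docstring_decorator(obj):
        obj.__doc__ = "".join(docstr) + (obj.__doc__ or "")
        return obj
    return docstring_decorator
```

This is why the generation example above shows up in the rendered docs for `BartForConditionalGeneration` even though the class body itself is unchanged.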
