diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index 249e2d171836..b37718c0b8fe 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -936,11 +936,8 @@ def forward(
     def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"
 
-        # first step, decoder_cached_states are empty
-        if not past[1]:
-            encoder_outputs, decoder_cached_states = past, None
-        else:
-            encoder_outputs, decoder_cached_states = past
+        encoder_outputs, decoder_cached_states = past
+
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py
index 9ac085495f43..444c7e858d21 100644
--- a/src/transformers/modeling_t5.py
+++ b/src/transformers/modeling_t5.py
@@ -1160,11 +1160,7 @@ def forward(
     def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"
 
-        # first step
-        if len(past) < 2:
-            encoder_outputs, decoder_past_key_value_states = past, None
-        else:
-            encoder_outputs, decoder_past_key_value_states = past[0], past[1]
+        encoder_outputs, decoder_past_key_value_states = past
 
         return {
             "decoder_input_ids": input_ids,
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index fc75984c8bf6..97a28ff34531 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -458,6 +458,7 @@ def _do_output_past(self, outputs):
     def generate(
         self,
         input_ids=None,
+        prefix_ids=None,
         max_length=None,
         min_length=None,
         do_sample=None,
@@ -488,15 +489,21 @@
 
         Parameters:
 
-            input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
+            input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
+                Language model: alias for prefix_ids, the sequence used as a prompt for the generation.
+                Seq2seq model: the sequence input to the encoder model.
+
+            prefix_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
                 The sequence used as a prompt for the generation. If `None` the method initializes
                 it as an empty `torch.LongTensor` of shape `(1,)`.
+                Language model: If `None`, uses input_ids.
 
             max_length: (`optional`) int
                 The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
 
             min_length: (`optional`) int
                 The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
+
             do_sample: (`optional`) bool
                 If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
@@ -601,6 +608,17 @@
                 "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
             )
 
+        if self.config.is_encoder_decoder:
+            bos_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+            bos_token_id = bos_token_id if bos_token_id is not None else self.config.decoder_start_token_id
+            bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+
+            assert (
+                bos_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+        else:
+            prefix_ids = prefix_ids if prefix_ids is not None else input_ids
+
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
@@ -621,12 +639,11 @@
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
-        )
 
         if input_ids is not None:
-            batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
+            batch_size = input_ids.shape[0]  # overriden by the input batch_size
+        elif prefix_ids is not None:
+            batch_size = prefix_ids.shape[0]
         else:
             batch_size = 1
 
@@ -639,7 +656,7 @@
         assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
-        assert input_ids is not None or (
+        assert prefix_ids is not None or (
             isinstance(bos_token_id, int) and bos_token_id >= 0
-        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+        ), "If input_ids and prefix_ids are not defined, `bos_token_id` should be a positive integer."
         assert pad_token_id is None or (
@@ -656,14 +673,16 @@
         assert (
             bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
         ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
-        if input_ids is None:
+        if prefix_ids is None:
             assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
-                "you should either supply a context to complete as `input_ids` input "
+                "you should either supply a context to complete as `input_ids` or `prefix_ids` input "
                 "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
             )
-            input_ids = tf.fill((batch_size, 1), bos_token_id)
+            prefix_ids = tf.fill((batch_size, 1), bos_token_id)
+            if input_ids is None:
+                input_ids = tf.fill((batch_size, 1), bos_token_id)
         else:
-            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
+            assert len(shape_list(prefix_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
 
         # not allow to duplicate outputs when greedy decoding
         if do_sample is False:
@@ -693,7 +712,7 @@
             pad_token_id = eos_token_id
 
         # current position and vocab size
-        cur_len = shape_list(input_ids)[1]
+        cur_len = shape_list(prefix_ids)[1]
         vocab_size = self.config.vocab_size
 
         # set effective batch size and effective batch multiplier according to do_sample
@@ -707,45 +726,41 @@
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = shape_list(input_ids)[-1]
-            input_ids = tf.broadcast_to(
-                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
+            prefix_ids_len = shape_list(prefix_ids)[-1]
+            prefix_ids = tf.broadcast_to(
+                tf.expand_dims(prefix_ids, 1), (batch_size, effective_batch_mult * num_beams, prefix_ids_len)
             )
             attention_mask = tf.broadcast_to(
                 tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
             )
-            input_ids = tf.reshape(
-                input_ids, (effective_batch_size * num_beams, input_ids_len)
+            prefix_ids = tf.reshape(
+                prefix_ids, (effective_batch_size * num_beams, prefix_ids_len)
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
             attention_mask = tf.reshape(
                 attention_mask, (effective_batch_size * num_beams, input_ids_len)
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
 
+            if self.config.is_encoder_decoder:
+                input_ids = tf.broadcast_to(
+                    tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
+                )
+                input_ids = tf.reshape(
+                    input_ids, (effective_batch_size * num_beams, input_ids_len)
+                )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+
         # If encoder-decoder, get encoder outputs
         if self.config.is_encoder_decoder:
-            if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
-
-            assert (
-                decoder_start_token_id is not None
-            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
 
             # get encoder and store encoder outputs
             encoder = self.get_encoder()
-
             encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
-
-            # create empty decoder_input_ids
-            input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
-            cur_len = 1
-
         else:
             encoder_outputs = None
-            cur_len = shape_list(input_ids)[-1]
 
         if num_beams > 1:
             output = self._generate_beam_search(
-                input_ids,
+                input_ids=prefix_ids,
                 cur_len=cur_len,
                 max_length=max_length,
                 min_length=min_length,
@@ -771,7 +786,7 @@
             )
         else:
             output = self._generate_no_beam_search(
-                input_ids,
+                input_ids=prefix_ids,
                 cur_len=cur_len,
                 max_length=max_length,
                 min_length=min_length,
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f52e2b6fa20c..47d776f1f1d5 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -658,6 +658,7 @@ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output
     def generate(
         self,
         input_ids=None,
+        prefix_ids=None,
         max_length=None,
         min_length=None,
         do_sample=None,
@@ -688,8 +689,13 @@
         Parameters:
 
             input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
+                Language model: alias for prefix_ids, the sequence used as a prompt for the generation.
+                Seq2seq model: the sequence input to the encoder model.
+
+            prefix_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                 The sequence used as a prompt for the generation. If `None` the method initializes
                 it as an empty `torch.LongTensor` of shape `(1,)`.
+                Language model: If `None`, uses input_ids.
 
             max_length: (`optional`) int
                 The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20.
@@ -732,8 +738,10 @@
 
             no_repeat_ngram_size: (`optional`) int
                 If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
+
             bad_words_ids: (`optional`) list of lists of int
-                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
+                `bad_words_ids` contains tokens that are not allowed to be generated.
+                In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
 
             num_return_sequences: (`optional`) int
                 The number of independently computed returned sequences for each element in the batch. Default to 1.
@@ -800,6 +808,17 @@
                 "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
             )
 
+        if self.config.is_encoder_decoder:
+            bos_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+            bos_token_id = bos_token_id if bos_token_id is not None else self.config.decoder_start_token_id
+            bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+
+            assert (
+                bos_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+        else:
+            prefix_ids = prefix_ids if prefix_ids is not None else input_ids
+
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
@@ -820,12 +839,11 @@
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
-        )
 
         if input_ids is not None:
             batch_size = input_ids.shape[0]  # overriden by the input batch_size
+        elif prefix_ids is not None:
+            batch_size = prefix_ids.shape[0]
         else:
             batch_size = 1
 
@@ -838,9 +856,9 @@
         assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
-        assert input_ids is not None or (
+        assert prefix_ids is not None or (
             isinstance(bos_token_id, int) and bos_token_id >= 0
-        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+ ), "If input_ids and prefix_ids are not defined, `bos_token_id` should be a positive integer." assert pad_token_id is None or ( isinstance(pad_token_id, int) and (pad_token_id >= 0) ), "`pad_token_id` should be a positive integer." @@ -858,16 +876,20 @@ def generate( bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" - if input_ids is None: + if prefix_ids is None: assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( - "you should either supply a context to complete as `input_ids` input " + "you should either supply a context to complete as `input_ids` or `prefix_ids` input " "or a `bos_token_id` (integer >= 0) as a first token to start the generation." ) - input_ids = torch.full( + prefix_ids = torch.full( (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device, ) + if input_ids is None: + input_ids = torch.full( + (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device, + ) else: - assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." + assert prefix_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." # not allow to duplicate outputs when greedy decoding if do_sample is False: @@ -909,50 +931,34 @@ def generate( effective_batch_size = batch_size effective_batch_mult = 1 + # If encoder-decoder, get encoder outputs before expanding attention mask if self.config.is_encoder_decoder: - if decoder_start_token_id is None: - decoder_start_token_id = bos_token_id - - assert ( - decoder_start_token_id is not None - ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) # get encoder and store encoder outputs encoder = self.get_encoder() - encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + else: + encoder_outputs = None # Expand input ids if num_beams > 1 or num_return_sequences > 1 if num_return_sequences > 1 or num_beams > 1: input_ids_len = input_ids.shape[-1] - input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) + prefix_ids_len = prefix_ids.shape[-1] # different in the encoder-decoder setting + prefix_ids = prefix_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, prefix_ids_len) attention_mask = attention_mask.unsqueeze(1).expand( batch_size, effective_batch_mult * num_beams, input_ids_len ) - input_ids = input_ids.contiguous().view( - effective_batch_size * num_beams, input_ids_len + prefix_ids = prefix_ids.contiguous().view( + effective_batch_size * num_beams, prefix_ids_len ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) attention_mask = attention_mask.contiguous().view( effective_batch_size * num_beams, input_ids_len ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - # create empty decoder_input_ids - input_ids = torch.full( - (effective_batch_size * num_beams, 1), - decoder_start_token_id, - dtype=torch.long, - device=next(self.parameters()).device, - ) - cur_len = 1 - - assert ( - batch_size == encoder_outputs[0].shape[0] - ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " - 
             # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
             expanded_batch_idxs = (
                 torch.arange(batch_size)
@@ -964,13 +970,11 @@ def generate(
             # expand encoder_outputs
             encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
 
-        else:
-            encoder_outputs = None
-            cur_len = input_ids.shape[-1]
+        cur_len = prefix_ids.shape[-1]
 
         if num_beams > 1:
             output = self._generate_beam_search(
-                input_ids,
+                input_ids=prefix_ids,
                 cur_len=cur_len,
                 max_length=max_length,
                 min_length=min_length,
@@ -996,7 +1000,7 @@
             )
         else:
             output = self._generate_no_beam_search(
-                input_ids,
+                input_ids=prefix_ids,
                 cur_len=cur_len,
                 max_length=max_length,
                 min_length=min_length,
@@ -1046,7 +1050,9 @@ def _generate_no_beam_search(
         unfinished_sents = input_ids.new(batch_size).fill_(1)
         sent_lengths = input_ids.new(batch_size).fill_(max_length)
 
-        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
+        past = (
+            (encoder_outputs, None) if encoder_outputs is not None else None
+        )  # defined for encoder-decoder models, None for decoder-only models
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
@@ -1179,7 +1185,9 @@ def _generate_beam_search(
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
 
         # cache compute states
-        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
+        past = (
+            (encoder_outputs, None) if encoder_outputs is not None else None
+        )  # defined for encoder-decoder models, None for decoder-only models
 
         # done sentences
         done = [False for _ in range(batch_size)]
@@ -1334,7 +1342,7 @@ def _generate_beam_search(
             input_ids = input_ids[beam_idx, :]
             input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
             # re-order internal states
-            if past is not None:
+            if self._do_output_past(outputs):
                 past = self._reorder_cache(past, beam_idx)
 
             # extend attention_mask for new generated input if only decoder
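
Reviewer note on the cache convention (not part of the diff): with the `past = ((encoder_outputs, None) ...)` seeding above, the `past` handed to `prepare_inputs_for_generation` is always a 2-tuple for encoder-decoder models, which is exactly what lets the BART/T5 hunks at the top unpack it unconditionally. A minimal sketch of the invariant, with a placeholder standing in for the real encoder output:

```python
# Sketch of the invariant introduced above (illustrative; the string is a
# hypothetical placeholder for the model's real encoder output tuple).
encoder_outputs = ("dummy_encoder_hidden_states",)

# Seeded once before the decoding loop (encoder-decoder case):
past = (encoder_outputs, None) if encoder_outputs is not None else None

# prepare_inputs_for_generation can now unpack without a first-step special case:
encoder_outputs, decoder_cached_states = past
assert decoder_cached_states is None  # first step: decoder cache not built yet
```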
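To make the new calling convention concrete, here is a short usage sketch against this branch's `generate` signature. It is illustrative only: the checkpoint names, prompts, and the explicit decoder seed are examples (not taken from the PR), and the fallback mirrors the `decoder_start_token_id`/`bos_token_id` logic added above.

```python
# Illustrative usage sketch for this branch (assumptions noted inline).
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")  # example checkpoint
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Seq2seq model: `input_ids` always feeds the encoder...
input_ids = tokenizer.encode("My friends are cool but they eat too many carbs.", return_tensors="pt")

# ...while `prefix_ids` seeds the decoder. If omitted, generate() falls back to
# decoder_start_token_id / bos_token_id exactly as the hunks above implement.
start = model.config.decoder_start_token_id
start = start if start is not None else model.config.bos_token_id
prefix_ids = torch.tensor([[start]], dtype=torch.long)

summary_ids = model.generate(input_ids=input_ids, prefix_ids=prefix_ids, num_beams=4, max_length=40)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

# Decoder-only models are unchanged in practice: per the new docstring,
# `input_ids` acts as an alias for `prefix_ids`, so existing calls keep working.
```

The design point worth noting is that `input_ids` keeps its encoder-side length while `prefix_ids` drives `cur_len`, which is why the expansion code above tracks `input_ids_len` and `prefix_ids_len` separately.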