from typing import Optional

import numpy as np
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.modeling_utils import ModuleUtilsMixin

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin


# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the MLP if we encode the prefix.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*, defaults to `None`):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    @register_to_config
    def __init__(
        self,
        prefix_length: int,
        prefix_inner_dim: int,
        prefix_hidden_dim: Optional[int] = None,
        vocab_size: int = 50257,  # Start of GPT2 config args
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        super().__init__()

        self.prefix_length = prefix_length

        if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
            raise ValueError(
                f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_inner_dim} and"
                f" `n_embd`: {n_embd} are not equal."
            )

        self.prefix_inner_dim = prefix_inner_dim
        self.prefix_hidden_dim = prefix_hidden_dim

        self.encode_prefix = (
            nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
            if self.prefix_hidden_dim is not None
            else nn.Identity()
        )
        self.decode_prefix = (
            nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
        )

        gpt_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

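    # Illustrative configuration sketch (the values below are assumptions for exposition, not UniDiffuser's
    # trained defaults): when `prefix_inner_dim` differs from `n_embd`, `prefix_hidden_dim` must be set so the
    # prefix is routed through the encode/decode bottleneck before reaching GPT-2, e.g.
    #
    #     decoder = UniDiffuserTextDecoder(prefix_length=8, prefix_inner_dim=1024, prefix_hidden_dim=64, n_embd=768)
    #
    # With `prefix_hidden_dim=None`, `prefix_inner_dim` must equal `n_embd` and both `encode_prefix` and
    # `decode_prefix` are `nn.Identity()`.
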
    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
                Text tokens to use for inference.
            prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, prefix_inner_dim)`):
                Prefix embedding to prepend to the embedded tokens.
            attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
                Attention mask for the prefix embedding and the embedded tokens.
            labels (`torch.Tensor`, *optional*):
                Labels to use for language modeling.
        """
        embedding_text = self.transformer.transformer.wte(input_ids)
        # Project the prefix through the optional bottleneck and prepend it to the token embeddings.
        hidden = self.encode_prefix(prefix_embeds)
        prefix_embeds = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)

        if labels is not None:
            # Pad the labels with placeholder tokens so they line up with the prepended prefix positions.
            dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
            labels = torch.cat((dummy_token, input_ids), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
        if self.prefix_hidden_dim is not None:
            return out, hidden
        else:
            return out

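    # Forward-pass sketch (placeholder names and shapes, assuming `prefix_inner_dim == n_embd == 768` and
    # `prefix_length == 4`):
    #
    #     out = decoder(
    #         input_ids=token_ids,             # (N, max_seq_len) int64 token ids
    #         prefix_embeds=clip_text_embeds,  # (N, 4, 768) prefix embeddings
    #         labels=token_ids,                # enables the GPT-2 language-modeling loss
    #     )
    #
    # If `prefix_hidden_dim` is set, the call instead returns `(out, hidden)`, where `hidden` is the encoded
    # (bottlenecked) prefix.
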
    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def encode(self, prefix):
        return self.encode_prefix(prefix)

    @torch.no_grad()
    def generate_captions(self, features, eos_token_id, device):
        """
        Generate captions given text embedding features.

        Args:
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            eos_token_id (`int`):
                The token ID of the EOS token for the text decoder model.
            device:
                Device to perform text generation on.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: A tuple of the generated caption token sequences (best beam per
            example in `features`) and their corresponding sequence lengths.
        """

        features = torch.split(features, 1, dim=0)
        generated_tokens = []
        generated_seq_lengths = []
        for feature in features:
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            output_tokens, seq_lengths = self.generate_beam(
                input_embeds=feature, device=device, eos_token_id=eos_token_id
            )
            generated_tokens.append(output_tokens[0])
            generated_seq_lengths.append(seq_lengths[0])
        generated_tokens = torch.stack(generated_tokens)
        generated_seq_lengths = torch.stack(generated_seq_lengths)
        return generated_tokens, generated_seq_lengths

    @torch.no_grad()
    def generate_beam(
        self,
        input_ids=None,
        input_embeds=None,
        device=None,
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """
        Generates text from a text prompt (`input_ids`) or a token embedding (`input_embeds`) via beam search. This
        implementation is based on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        Args:
            eos_token_id (`int`, *optional*):
                The token ID of the EOS token for the text decoder model.
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                must be supplied.
            input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                `input_ids` and `input_embeds` must be supplied.
            device:
                The device to perform beam search on.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The temperature to use when performing the softmax over logits from the decoding model.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
            token sequences sorted by score in descending order, and the second element is the sequence lengths
            corresponding to those sequences.
        """
        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = eos_token_id
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

        if input_embeds is not None:
            generated = input_embeds
        else:
            generated = self.transformer.transformer.wte(input_ids)

        for i in range(entry_length):
            outputs = self.transformer(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()

            if scores is None:
                # First step: expand the single prefix into `beam_size` candidate beams.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # Finished beams may only "extend" with token 0 at zero additional cost.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                # Accumulate log-probabilities, length-normalize, and keep the top `beam_size` continuations.
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break

        scores = scores / seq_lengths
        order = scores.argsort(descending=True)
        # tokens tensors are already padded to max_seq_length
        output_texts = [tokens[i] for i in order]
        output_texts = torch.stack(output_texts, dim=0)
        seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype)
        return output_texts, seq_lengths
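

# Minimal caption-generation sketch. The import path and all hyperparameters below are illustrative assumptions
# for a small smoke test, not the configuration of the released UniDiffuser checkpoints:
#
#     import torch
#     from diffusers.pipelines.unidiffuser.modeling_text_decoder import UniDiffuserTextDecoder
#
#     decoder = UniDiffuserTextDecoder(
#         prefix_length=8, prefix_inner_dim=64, prefix_hidden_dim=32, n_embd=64, n_layer=2, n_head=2, vocab_size=100
#     )
#     decoder.eval()
#     # (B, prefix_length, prefix_hidden_dim) features, e.g. as produced by `decoder.encode(...)`
#     features = torch.randn(1, 8, 32)
#     tokens, lengths = decoder.generate_captions(features, eos_token_id=0, device="cpu")
#     # `tokens` holds the highest-scoring beam's token ids for each input; decode them with the pipeline's
#     # GPT-2 tokenizer to obtain caption strings.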