#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import importlib
from torch import nn
from typing import Optional, Tuple, List
from transformers.models.clip.modeling_clip import CLIPAttention
from ipex_llm.utils.common.log4Error import invalidInputError


def merge_qkv(module: torch.nn.Module):
    if isinstance(module, CLIPAttention):
        # fuse the separate q/k/v projections of the vision encoder's attention
        # into a single projection by concatenating their weights (and biases)
        # along the output dimension
        new_weight = torch.cat([
            module.q_proj.weight.data,
            module.k_proj.weight.data,
            module.v_proj.weight.data,
        ], dim=0)

        if module.q_proj.bias is not None:
            # placeholder Linear; its weight/bias and feature sizes are overwritten below
            qkv_proj = torch.nn.Linear(0, 0, bias=True)
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=0)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        else:
            qkv_proj = torch.nn.Linear(0, 0, bias=False)
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        qkv_proj.in_features = new_weight.size(1)
        qkv_proj.out_features = new_weight.size(0)
        module.qkv_proj = qkv_proj

        del module.q_proj, module.k_proj, module.v_proj


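# Note on `merge_qkv` above: because the q/k/v weights are concatenated along
# the output dimension (dim 0), the fused projection satisfies
#
#     qkv_proj(x) == torch.cat([q_proj(x), k_proj(x), v_proj(x)], dim=-1)
#
# so splitting its output back into `num_heads`-sized chunks per tensor (as done
# in `phi3v_encoder_attention_forward` below) recovers exactly the original
# query, key and value states, with one GEMM instead of three.

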
def phi3v_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    # ipex-llm changes start
    from ipex_llm.transformers.kv import DynamicNormalCache
    # IPEX-LLM OPT: kv cache
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if use_cache:
        if not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    modeling_module_name = self.__class__.__module__
    module = importlib.import_module(modeling_module_name)
    return module.Phi3VModel.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        pixel_values=pixel_values,
        image_sizes=image_sizes,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )


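# Note on the wrapper above: only the key/value cache handling is changed
# (legacy tuple caches are wrapped into ipex-llm's DynamicNormalCache once per
# call); all other arguments are forwarded untouched to the checkpoint's own
# Phi3VModel.forward. Phi-3-vision ships its modeling code as remote code, so
# that class is looked up through importlib on `self.__class__.__module__`
# rather than imported statically.

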
def phi3v_encoder_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    causal_attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, tgt_len, embed_dim = hidden_states.size()

    # project once through the fused qkv_proj created by merge_qkv, then split
    # the heads back into separate query/key/value tensors
    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, tgt_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    # apply the query scaling used by the original CLIPAttention
    # (self.scale == head_dim ** -0.5)
    query_states = query_states * self.scale

    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = query_states.reshape(*proj_shape)
    key_states = key_states.reshape(*proj_shape)
    value_states = value_states.reshape(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        invalidInputError(
            False,
            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)},"
            f" but is {attn_weights.size()}"
        )

    # apply the causal_attention_mask first
    if causal_attention_mask is not None:
        if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                f" {causal_attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) \
            + causal_attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)},"
                f" but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if output_attentions:
        # this operation is a bit awkward, but it's required to make sure that
        # attn_weights keeps its gradient: attn_weights has to be reshaped twice
        # and the reshaped tensor has to be reused in the following computation
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)},"
            f" but is {attn_output.size()}"
        )

    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights_reshaped
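

# A rough wiring sketch (an assumption about how this module is typically used
# during ipex-llm's model conversion, not code taken from this file): the
# converter first fuses the projections and then rebinds the forward methods,
# along the lines of
#
#     model.apply(merge_qkv)
#     modeling_module = importlib.import_module(model.__class__.__module__)
#     modeling_module.Phi3VModel.forward = phi3v_model_forward
#     CLIPAttention.forward = phi3v_encoder_attention_forward
#
# The names and call sites above are illustrative only; the real registration
# lives in ipex-llm's convert logic.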