File tree Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Original file line number Diff line number Diff line change 1
- import math
2
1
from dataclasses import dataclass
3
2
from typing import Dict , Optional , Union
4
3
@@ -249,11 +248,7 @@ def forward(
249
248
# but time_embedding might be fp16, so we need to cast here.
250
249
timesteps_projected = timesteps_projected .to (dtype = self .dtype )
251
250
time_embeddings = self .time_embedding (timesteps_projected )
252
-
253
- # Rescale the features to have unit variance
254
- # YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline
255
- if self .clip_mean is None :
256
- proj_embedding = math .sqrt (proj_embedding .shape [1 ]) * proj_embedding
251
+
257
252
proj_embeddings = self .embedding_proj (proj_embedding )
258
253
if self .encoder_hidden_states_proj is not None and encoder_hidden_states is not None :
259
254
encoder_hidden_states = self .encoder_hidden_states_proj (encoder_hidden_states )
Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import math
15
16
from dataclasses import dataclass
16
17
from typing import List , Optional , Union
17
18
@@ -242,6 +243,9 @@ def _encode_prompt(
242
243
# Here we concatenate the unconditional and text embeddings into a single batch
243
244
# to avoid doing two forward passes
244
245
prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
246
+
247
+ # Rescale the features to have unit variance (this step is taken from the original repo)
248
+ prompt_embeds = math .sqrt (prompt_embeds .shape [1 ]) * prompt_embeds
245
249
246
250
return prompt_embeds
247
251
You can’t perform that action at this time.
0 commit comments