Skip to content

Commit 5b5a8e6

Browse files
author
yiyixuxu
committed
move the rescaling of prompt_embeds from prior_transformer to the pipeline
1 parent 6ec68ee commit 5b5a8e6

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

src/diffusers/models/prior_transformer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
from dataclasses import dataclass
32
from typing import Dict, Optional, Union
43

@@ -249,11 +248,7 @@ def forward(
249248
# but time_embedding might be fp16, so we need to cast here.
250249
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
251250
time_embeddings = self.time_embedding(timesteps_projected)
252-
253-
# Rescale the features to have unit variance
254-
# YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline
255-
if self.clip_mean is None:
256-
proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
251+
257252
proj_embeddings = self.embedding_proj(proj_embedding)
258253
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
259254
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)

src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import math
1516
from dataclasses import dataclass
1617
from typing import List, Optional, Union
1718

@@ -242,6 +243,9 @@ def _encode_prompt(
242243
# Here we concatenate the unconditional and text embeddings into a single batch
243244
# to avoid doing two forward passes
244245
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
246+
247+
# Rescale the features to have unit variance (this step is taken from the original repo)
248+
prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
245249

246250
return prompt_embeds
247251

0 commit comments

Comments (0)