Use torch in get_2d_sincos_pos_embed and get_3d_sincos_pos_embed #10156


Merged: 7 commits merged into huggingface:main on Dec 13, 2024

Conversation

@hlky (Contributor) commented Dec 9, 2024

What does this PR do?

Refactors get_2d_sincos_pos_embed and get_3d_sincos_pos_embed to use torch instead of numpy, and adds a device argument so that tensors can be created directly on e.g. CUDA.

Usage of get_2d_sincos_pos_embed and get_3d_sincos_pos_embed is updated to pass device where applicable. We don't specify a device during initialization, so device is not passed when the functions are called from __init__; at that point the device from the weights would just be cpu.

The torch and numpy versions match numerically.

Reproduction `get_2d_sincos_pos_embed`
from diffusers.models.embeddings import get_2d_sincos_pos_embed
import torch
from typing import Optional


def get_2d_sincos_pos_embed_torch(
  embed_dim,
  grid_size,
  cls_token=False,
  extra_tokens=0,
  interpolation_scale=1.0,
  base_size=16,
  device: Optional[torch.device] = None,
):
  """
  Creates 2D sinusoidal positional embeddings.

  Args:
      embed_dim (`int`):
          The embedding dimension.
      grid_size (`int`):
          The size of the grid height and width.
      cls_token (`bool`, defaults to `False`):
          Whether or not to add a classification token.
      extra_tokens (`int`, defaults to `0`):
          The number of extra tokens to add.
      interpolation_scale (`float`, defaults to `1.0`):
          The scale of the interpolation.
      base_size (`int`, defaults to `16`):
          The base size used to rescale the grid positions.
      device (`torch.device`, *optional*):
          The device on which to create the embeddings.

  Returns:
      pos_embed (`torch.Tensor`):
          Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
          embed_dim]` if using cls_token
  """
  if isinstance(grid_size, int):
      grid_size = (grid_size, grid_size)

  grid_h = (
      torch.arange(grid_size[0], device=device, dtype=torch.float32)
      / (grid_size[0] / base_size)
      / interpolation_scale
  )
  grid_w = (
      torch.arange(grid_size[1], device=device, dtype=torch.float32)
      / (grid_size[1] / base_size)
      / interpolation_scale
  )
  grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
  grid = torch.stack(grid, dim=0)

  grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
  pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)
  if cls_token and extra_tokens > 0:
      pos_embed = torch.concat(
          [torch.zeros([extra_tokens, embed_dim], device=pos_embed.device), pos_embed], dim=0
      )
  return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
  r"""
  This function generates 2D sinusoidal positional embeddings from a grid.

  Args:
      embed_dim (`int`): The embedding dimension.
      grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.

  Returns:
      `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
  """
  if embed_dim % 2 != 0:
      raise ValueError("embed_dim must be divisible by 2")

  # use half of dimensions to encode grid_h
  emb_h = get_1d_sincos_pos_embed_from_grid_torch(
      embed_dim // 2, grid[0]
  )  # (H*W, D/2)
  emb_w = get_1d_sincos_pos_embed_from_grid_torch(
      embed_dim // 2, grid[1]
  )  # (H*W, D/2)

  emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
  return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
  """
  This function generates 1D positional embeddings from a grid.

  Args:
      embed_dim (`int`): The embedding dimension `D`
      pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`

  Returns:
      `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
  """
  if embed_dim % 2 != 0:
      raise ValueError("embed_dim must be divisible by 2")

  omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
  omega /= embed_dim / 2.0
  omega = 1.0 / 10000**omega  # (D/2,)

  pos = pos.reshape(-1)  # (M,)
  out = torch.outer(pos, omega)  # (M, D/2), outer product

  emb_sin = torch.sin(out)  # (M, D/2)
  emb_cos = torch.cos(out)  # (M, D/2)

  emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
  return emb


hidden_size = 2560
pos_embed_max_size = 128

pos_embed_np = get_2d_sincos_pos_embed(
  hidden_size, pos_embed_max_size, base_size=pos_embed_max_size
)

pos_embed = get_2d_sincos_pos_embed_torch(
  hidden_size, pos_embed_max_size, base_size=pos_embed_max_size
)

torch.testing.assert_close(pos_embed, torch.from_numpy(pos_embed_np))
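For reference, a minimal sketch of how the new `device` argument could be used once this PR is merged (the `output_type="pt"` keyword comes from the deprecation plan discussed later in this thread; the availability of CUDA and the final keyword names are assumptions here):

```python
import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed

# Create the positional embeddings directly on the GPU instead of building a
# numpy array on the CPU and moving it afterwards.
if torch.cuda.is_available():
    pos_embed_cuda = get_2d_sincos_pos_embed(
        2560,                        # hidden_size, as in the reproduction above
        128,                         # pos_embed_max_size
        base_size=128,
        device=torch.device("cuda"),
        output_type="pt",            # assumed keyword: return torch.Tensor, not np.ndarray
    )
    print(pos_embed_cuda.device, pos_embed_cuda.shape)
```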
Reproduction `get_3d_sincos_pos_embed`
import torch
from typing import Optional, Tuple, Union
import numpy as np

embed_dim = 1920
sample_height = 60
sample_width = 90
sample_frames = 49
patch_size = 2
temporal_compression_ratio = 4
spatial_interpolation_scale = 1.875
temporal_interpolation_scale = 1.0
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
  r"""
  This function generates 2D sinusoidal positional embeddings from a grid.

  Args:
      embed_dim (`int`): The embedding dimension.
      grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.

  Returns:
      `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
  """
  if embed_dim % 2 != 0:
      raise ValueError("embed_dim must be divisible by 2")

  # use half of dimensions to encode grid_h
  emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
  emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

  emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
  return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  """
  This function generates 1D positional embeddings from a grid.

  Args:
      embed_dim (`int`): The embedding dimension `D`
      pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`

  Returns:
      `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
  """
  if embed_dim % 2 != 0:
      raise ValueError("embed_dim must be divisible by 2")

  omega = np.arange(embed_dim // 2, dtype=np.float64)
  omega /= embed_dim / 2.0
  omega = 1.0 / 10000**omega  # (D/2,)

  pos = pos.reshape(-1)  # (M,)
  out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

  emb_sin = np.sin(out)  # (M, D/2)
  emb_cos = np.cos(out)  # (M, D/2)

  emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
  return emb


def get_3d_sincos_pos_embed(
  embed_dim: int,
  spatial_size: Union[int, Tuple[int, int]],
  temporal_size: int,
  spatial_interpolation_scale: float = 1.0,
  temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
  r"""
  Creates 3D sinusoidal positional embeddings.

  Args:
      embed_dim (`int`):
          The embedding dimension of inputs. It must be divisible by 16.
      spatial_size (`int` or `Tuple[int, int]`):
          The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
          spatial dimensions (height and width).
      temporal_size (`int`):
          The temporal dimension of positional embeddings (number of frames).
      spatial_interpolation_scale (`float`, defaults to 1.0):
          Scale factor for spatial grid interpolation.
      temporal_interpolation_scale (`float`, defaults to 1.0):
          Scale factor for temporal grid interpolation.

  Returns:
      `np.ndarray`:
          The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
          embed_dim]`.
  """
  if embed_dim % 4 != 0:
      raise ValueError("`embed_dim` must be divisible by 4")
  if isinstance(spatial_size, int):
      spatial_size = (spatial_size, spatial_size)

  embed_dim_spatial = 3 * embed_dim // 4
  embed_dim_temporal = embed_dim // 4

  # 1. Spatial
  grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
  grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
  grid = np.meshgrid(grid_w, grid_h)  # here w goes first
  grid = np.stack(grid, axis=0)

  grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
  pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

  # 2. Temporal
  grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
  pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

  # 3. Concat
  pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
  pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]

  pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
  pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]

  pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
  return pos_embed


pos_embedding_np = get_3d_sincos_pos_embed(
  embed_dim,
  (post_patch_width, post_patch_height),
  post_time_compression_frames,
  spatial_interpolation_scale,
  temporal_interpolation_scale,
)


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
  r"""
  This function generates 2D sinusoidal positional embeddings from a grid.

  Args:
      embed_dim (`int`): The embedding dimension.
      grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.

  Returns:
      `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
  """
  if embed_dim % 2 != 0:
      raise ValueError("embed_dim must be divisible by 2")

  # use half of dimensions to encode grid_h
  emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0])  # (H*W, D/2)
  emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1])  # (H*W, D/2)

  emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
  return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
  """
  This function generates 1D positional embeddings from a grid.

  Args:
      embed_dim (`int`): The embedding dimension `D`
      pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`

  Returns:
      `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
  """
  if embed_dim % 2 != 0:
      raise ValueError("embed_dim must be divisible by 2")

  omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
  omega /= embed_dim / 2.0
  omega = 1.0 / 10000**omega  # (D/2,)

  pos = pos.reshape(-1)  # (M,)
  out = torch.outer(pos, omega)  # (M, D/2), outer product

  emb_sin = torch.sin(out)  # (M, D/2)
  emb_cos = torch.cos(out)  # (M, D/2)

  emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
  return emb

def get_3d_sincos_pos_embed_torch(
  embed_dim: int,
  spatial_size: Union[int, Tuple[int, int]],
  temporal_size: int,
  spatial_interpolation_scale: float = 1.0,
  temporal_interpolation_scale: float = 1.0,
  device: Optional[torch.device] = None,
) -> torch.Tensor:
  r"""
  Creates 3D sinusoidal positional embeddings.

  Args:
      embed_dim (`int`):
          The embedding dimension of inputs. It must be divisible by 16.
      spatial_size (`int` or `Tuple[int, int]`):
          The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
          spatial dimensions (height and width).
      temporal_size (`int`):
          The temporal dimension of positional embeddings (number of frames).
      spatial_interpolation_scale (`float`, defaults to 1.0):
          Scale factor for spatial grid interpolation.
      temporal_interpolation_scale (`float`, defaults to 1.0):
          Scale factor for temporal grid interpolation.
      device (`torch.device`, *optional*):
          The device on which to create the embeddings.

  Returns:
      `torch.Tensor`:
          The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
          embed_dim]`.
  """
  if embed_dim % 4 != 0:
      raise ValueError("`embed_dim` must be divisible by 4")
  if isinstance(spatial_size, int):
      spatial_size = (spatial_size, spatial_size)

  embed_dim_spatial = 3 * embed_dim // 4
  embed_dim_temporal = embed_dim // 4

  # 1. Spatial
  grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
  grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
  grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
  grid = torch.stack(grid, dim=0)

  grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
  pos_embed_spatial = get_2d_sincos_pos_embed_from_grid_torch(embed_dim_spatial, grid)

  # 2. Temporal
  grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
  pos_embed_temporal = get_1d_sincos_pos_embed_from_grid_torch(embed_dim_temporal, grid_t)

  # 3. Concat
  pos_embed_spatial = pos_embed_spatial[None, :, :]
  pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]

  pos_embed_temporal = pos_embed_temporal[:, None, :]
  pos_embed_temporal = pos_embed_temporal.repeat_interleave(spatial_size[0] * spatial_size[1], dim=1)  # [T, H*W, D // 4]

  pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
  return pos_embed

pos_embedding = get_3d_sincos_pos_embed_torch(
  embed_dim,
  (post_patch_width, post_patch_height),
  post_time_compression_frames,
  spatial_interpolation_scale,
  temporal_interpolation_scale,
)

torch.testing.assert_close(pos_embedding, torch.from_numpy(pos_embedding_np))

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@hlky (Contributor, author) commented Dec 9, 2024

Draft because I noticed get_3d_sincos_pos_embed needs refactoring at the same time as this.

@a-r-r-o-w (Member):

@hlky, thanks for taking this up! if it helps, there are some conversions in #9654 for each of the functions

@hlky hlky changed the title Use torch in get_2d_sincos_pos_embed Use torch in get_2d_sincos_pos_embed and get_3d_sincos_pos_embed Dec 9, 2024
@hlky hlky marked this pull request as ready for review December 9, 2024 15:05
@yiyixuxu (Collaborator) commented Dec 9, 2024

Ohh thanks!! I think we also need to run the slow tests of all models that use these two functions :)

@hlky (Contributor, author) commented Dec 10, 2024

Downstream usage (there's probably more):

This will break https://github.com/kijai/ComfyUI-CogVideoXWrapper/blob/795f8b05659dfa5ec6b216fb698bcca6fda34fdb/embeddings.py#L64-L71 because we now return a torch.Tensor and torch.from_numpy requires an np.ndarray. That repo uses the latest Diffusers release rather than main, so the breakage won't be immediate.
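A hedged sketch of how downstream code such as the linked wrapper could stay compatible with both the old and new return types (the helper name here is illustrative, not taken from that repo):

```python
import numpy as np
import torch

def as_torch_pos_embed(pos_embed):
    # Older diffusers releases return np.ndarray, this PR returns torch.Tensor;
    # accept either so torch.from_numpy is only called when it is actually needed.
    if isinstance(pos_embed, np.ndarray):
        return torch.from_numpy(pos_embed)
    return pos_embed
```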

@yiyixuxu (Collaborator):

Just to be extra safe, maybe we can deprecate it instead:

  1. move the current implementations into get_2d_sincos_pos_embed_np and get_3d_sincos_pos_embed_np
  2. add an output_type argument that defaults to "np"
  3. if output_type is "np", dispatch to the *_np methods and emit a deprecation message (a minimal sketch of this dispatch follows below)
  4. change to output_type="pt" throughout the diffusers code base

Let me know what you think!
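A minimal sketch of the step-3 dispatch (the `*_np` helper name follows the plan above, the torch path reuses the function from the reproduction earlier in this thread, and a plain FutureWarning stands in for diffusers' own deprecation utility):

```python
import warnings

def get_2d_sincos_pos_embed(embed_dim, grid_size, output_type="np", **kwargs):
    if output_type == "np":
        # Keep the old numpy behaviour as the default for now, behind a warning.
        warnings.warn(
            "`output_type='np'` is deprecated and will be removed; "
            "pass `output_type='pt'` to receive a torch.Tensor instead.",
            FutureWarning,
        )
        return get_2d_sincos_pos_embed_np(embed_dim, grid_size, **kwargs)
    # New torch path, e.g. the refactored implementation shown in the reproduction above.
    return get_2d_sincos_pos_embed_torch(embed_dim, grid_size, **kwargs)
```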

@hlky (Contributor, author) commented Dec 10, 2024

Yes, makes sense to deprecate it. By 0.33?

@yiyixuxu (Collaborator):

sure!

@hlky (Contributor, author) commented Dec 11, 2024

Added here c5bd771

@yiyixuxu (Collaborator) left a review:
Thanks! I left one comment. The code change looks good to me! Does this only affect Latte? Can we run the docstring examples for the affected models to make sure there is no output change before merge?

@@ -141,6 +156,66 @@ def get_3d_sincos_pos_embed(
return pos_embed


def get_3d_sincos_pos_embed_np(
@yiyixuxu (Collaborator) commented on the diff:
Make it a private method _get_3d_sincos_pos_embed_np and deprecate it here too, in case anyone wants to use it directly.

@hlky (Contributor, author) commented Dec 13, 2024

Affected models:

- get_3d_sincos_pos_embed: used by CogVideoXPatchEmbed
- get_2d_sincos_pos_embed_from_grid: used by get_3d_sincos_pos_embed, get_2d_sincos_pos_embed
- get_1d_sincos_pos_embed_from_grid: used by get_3d_sincos_pos_embed, get_2d_sincos_pos_embed_from_grid, LatteTransformer3DModel
- get_2d_sincos_pos_embed: used by get_3d_sincos_pos_embed, CogView3PlusPatchEmbed, PatchEmbed (in embeddings), PatchEmbed (in modeling_uvit, a duplicate but older version?)
- PatchEmbed: used by HunyuanDiT2DControlNetModel, SD3ControlNetModel, DiTTransformer2DModel, HunyuanDiT2DModel, LatteTransformer3DModel, PixArtTransformer2DModel, Transformer2DModel (deprecated? it's split now), AllegroTransformer3DModel, MochiTransformer3DModel, SD3Transformer2DModel, UniDiffuserModel/UTransformer2DModel (uvit)

I'll test the docstring examples for each model; it might take some time as there are a few large/slow models.
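A hedged sketch of the kind of output-hash comparison implied by the results below (the exact procedure isn't shown in the thread; fixed seeds and file hashing are assumptions):

```python
import hashlib

def file_sha256(path: str) -> str:
    # Hash a saved output (e.g. an image generated with a fixed seed) so results
    # from `main` and from this branch can be compared byte-for-byte.
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()

# Example usage (paths are illustrative):
# assert file_sha256("original/latte.png") == file_sha256("branch/latte.png")
```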

@hlky (Contributor, author) commented Dec 13, 2024

CogView3PlusPipeline: hash matched.

DiTPipeline: hash matched.

HunyuanDiTControlNet: different hash, branch output has a slightly larger file size.

HunyuanDiTPipeline: hash matched.

latte: hash matched.

mochi: hash matched (mochi.mp4, original vs. branch).

PixArtSigmaPipeline: hash matched.

sd3: hash matched.

StableDiffusion3ControlNet: different hash, branch output has a slightly larger file size.

(Side-by-side original vs. branch outputs were attached for each pipeline; images and videos omitted here.)

No visible difference I can see in the two ControlNet cases where hashes aren't a match.

Allegro needs separate testing as it's a very slow model.

@hlky (Contributor, author) commented Dec 13, 2024

Allegro: hash matched (AllegroPipeline.2.mp4 and AllegroPipeline.3.mp4; videos omitted here).

I just let it run, but we should follow this #10212 (comment) in the future.

@yiyixuxu (Collaborator):

wow!!! thank you @hlky

@yiyixuxu (Collaborator) left a review:
thank you!

@yiyixuxu merged commit 6324340 into huggingface:main on Dec 13, 2024 (12 checks passed).
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
Use torch in get_2d_sincos_pos_embed and get_3d_sincos_pos_embed (#10156)

* Use torch in get_2d_sincos_pos_embed

* Use torch in get_3d_sincos_pos_embed

* get_1d_sincos_pos_embed_from_grid in LatteTransformer3DModel

* deprecate

* move deprecate, make private