Make key optional in ipex.llm.functional.rotary_embedding #821

Open · wants to merge 1 commit into base: main
7 changes: 4 additions & 3 deletions intel_extension_for_pytorch/llm/functional/fusions.py
@@ -13,7 +13,7 @@

 def rotary_embedding(
     query: torch.Tensor,
-    key: torch.Tensor,
+    key: Optional[torch.Tensor],
     sin: torch.Tensor,
     cos: torch.Tensor,
     rotary_dim: int,
@@ -25,9 +25,10 @@ def rotary_embedding(
     on the `query ` or `key` before their multi-head attention computation.

     Args:
-        query, key (torch.Tensor) : inputs to be applied with position embeddings,
+        query (torch.Tensor), key (Optional[torch.Tensor]): inputs to be applied with position embeddings,
             taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
             or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+            `key` may be `None`, e.g. in case of cross-layer KV sharing.
         sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor
             generated to be applied on query/key.
         rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
@@ -42,7 +43,7 @@ def rotary_embedding(
             The according position_ids for the input. The shape should be [batch size, sequence length].

     Return
-        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        query (torch.Tensor), key (Optional[torch.Tensor]): [batch size, sequence length, num_head/num_kv_head, head_dim]
             or [num_tokens, num_head/num_kv_head, head_dim].

     """
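With this change, the functional entry point accepts calls with or without a key tensor. Below is a minimal usage sketch, assuming intel_extension_for_pytorch is installed and imported as ipex; the sin/cos construction is only a placeholder of the documented [num_tokens, rotary_dim] shape, not necessarily the frequency layout a given model uses.

```python
import torch
import intel_extension_for_pytorch as ipex

num_tokens, num_head, head_dim = 32, 8, 64
query = torch.randn(num_tokens, num_head, head_dim).float()
key = torch.randn(num_tokens, num_head, head_dim).float()

# Placeholder sin/cos of the documented [num_tokens, rotary_dim] shape
# (rotary_dim == head_dim here); real values come from the model's RoPE frequencies.
positions = torch.arange(num_tokens).float().unsqueeze(1)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = positions * inv_freq
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)

# Previous behavior: rotate both query and key.
q_rot, k_rot = ipex.llm.functional.rotary_embedding(query, key, sin, cos, head_dim, True)

# With this PR, key may be None (e.g. cross-layer KV sharing); only query is rotated.
q_only, k_none = ipex.llm.functional.rotary_embedding(query, None, sin, cos, head_dim, True)
assert k_none is None
```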
14 changes: 9 additions & 5 deletions intel_extension_for_pytorch/llm/modules/mha_fusion.py
@@ -49,14 +49,15 @@ class RotaryEmbedding(nn.Module):

     [Direct function call] This module also provides a `.apply_function` function call
     to be used on query and key at the same time without initializing the module
-    (assume rotary embedding sin/cos values are provided).
+    (assume rotary embedding sin/cos values are provided). `key` is optional for `.apply_function` call.

     `apply_function()`

     Args:
-        query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
+        query (torch.Tensor), key (Optional[torch.Tensor]) : inputs to be applied with position embeddings, taking shape of
             [batch size, sequence length, num_head/num_kv_head, head_dim]
             or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+            `key` may be None, e.g. in case of cross-layer KV sharing.
         sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
         rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
         head_dim (int) : head dim from the input shape.
@@ -68,7 +69,7 @@ class RotaryEmbedding(nn.Module):
             for the input. The shape should be [batch size, sequence length].

     Return:
-        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        query (torch.Tensor), key (Optional[torch.Tensor]): [batch size, sequence length, num_head/num_kv_head, head_dim]
             or [num_tokens, num_head/num_kv_head, head_dim].

     """
@@ -137,14 +138,17 @@ def forward(
     def apply_function(
         cls,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor],
         sin: torch.Tensor,
         cos: torch.Tensor,
         rotary_dim: int,
         rotary_half: bool,
         position_ids: torch.Tensor = None,
     ):
-        # query, key (in/out shape) torch.Tensor :
+        # query: torch.Tensor with in/out shape:
         #   4D: [batch, seqlen, num_head/num_kv_head, head_dim]
         #   3D: [num_tokens, num_head/num_kv_head, head_dim]
+        # key (optional) None or torch.Tensor with in/out shape:
+        #   4D: [batch, seqlen, num_head/num_kv_head, head_dim]
+        #   3D: [num_tokens, num_head/num_kv_head, head_dim]
         # sin, cos: torch.Tensor [num_tokens, rotary_dim]
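The same optional-key behavior applies to the direct classmethod call on the module. A short sketch follows, assuming the class is exposed as ipex.llm.modules.RotaryEmbedding and using placeholder sin/cos values of the documented shape purely for illustration.

```python
import torch
import intel_extension_for_pytorch as ipex

num_tokens, num_head, head_dim = 16, 8, 64
query = torch.randn(num_tokens, num_head, head_dim)
sin = torch.randn(num_tokens, head_dim)  # placeholder values, documented shape [num_tokens, rotary_dim]
cos = torch.randn(num_tokens, head_dim)

# Direct function call without constructing the module; key=None is now accepted.
q_rot, k_rot = ipex.llm.modules.RotaryEmbedding.apply_function(
    query, None, sin, cos, head_dim, True
)
assert k_rot is None
```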
@@ -65,23 +65,28 @@ def forward(
     def rotary_embedding(
         cls, query, key, sin, cos, rotary_dim, rotary_half, position_ids=None
     ):
-        # query, key (in/out shape) torch.Tensor :
+        # query: torch.Tensor with in/out shape:
         #   4D: [bs, seqlen, num_head/num_kv_head, head_dim]
         #   3D: [num_tokens, num_head/num_kv_head, head_dim]
+        # key (optional) None or torch.Tensor with in/out shape:
+        #   4D: [bs, seqlen, num_head/num_kv_head, head_dim]
+        #   3D: [num_tokens, num_head/num_kv_head, head_dim]
         # sin, cos: torch.Tensor [num_tokens, rotary_dim]
         # position_ids (optional): torch.Tensor [bs, seqlen]
         head_dim = query.size(-1)
         num_head = query.size(-2)
-        num_kv_head = key.size(-2)
+        num_kv_head = key.size(-2) if key is not None else 0
         input_3d = False
         assert (
-            key.dim() == query.dim() and query.dim() == 3 or query.dim() == 4
+            (key is None or key.dim() == query.dim())
+            and query.dim() == 3
+            or query.dim() == 4
         ), "rotary embedding query/key dim == 3 or 4"

         if query.dim() == 3:
             input_3d = True
             query_ = query.unsqueeze(0)
-            key_ = key.unsqueeze(0)
+            key_ = key.unsqueeze(0) if key is not None else None
         else:
             query_ = query
             key_ = key
@@ -124,21 +129,26 @@ def rotary_embedding(
                 rotary_dim,
             )

-            key_, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
-                key_,
-                sin_cos,
-                position_ids,
-                num_kv_head,
-                head_dim,
-                offset,
-                rotary_dim,
-            )
+            if key is not None:
+                key_, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
+                    key_,
+                    sin_cos,
+                    position_ids,
+                    num_kv_head,
+                    head_dim,
+                    offset,
+                    rotary_dim,
+                )
         if input_3d:
             query_ = query_.view([-1, num_head, head_dim])
-            key_ = key_.view([-1, num_kv_head, head_dim])
+            if key_ is not None:
+                key_ = key_.view([-1, num_kv_head, head_dim])
         # keep the inplace context as used in TGI
         query.copy_(query_)
-        key.copy_(key_)
+        if key is not None:
+            key.copy_(key_)

         return query, key
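For readers who want the control flow without the fused kernel, here is a simplified pure-PyTorch mirror of the patched backend path. The helper below stands in for torch.ops.torch_ipex.rotary_position_embedding, assumes rotation over the full head_dim with the rotate-half layout, and uses illustrative names only; it is a sketch of the key-optional logic, not the actual kernel.

```python
from typing import Optional, Tuple

import torch


def _rotate_half_ref(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # Naive rotate-half RoPE over the full head_dim, standing in for the fused op.
    # x: [bs, seqlen, num_head, head_dim]; sin/cos: [seqlen, head_dim], broadcast over heads.
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos.unsqueeze(-2) + rotated * sin.unsqueeze(-2)


def rotary_embedding_ref(
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    sin: torch.Tensor,
    cos: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Same shape handling as the patched backend: accept 3D or 4D inputs,
    # and skip every key-related step when key is None.
    input_3d = query.dim() == 3
    query_ = query.unsqueeze(0) if input_3d else query
    key_ = key.unsqueeze(0) if (input_3d and key is not None) else key

    query_ = _rotate_half_ref(query_, sin, cos)
    if key_ is not None:
        key_ = _rotate_half_ref(key_, sin, cos)

    if input_3d:
        query_ = query_.squeeze(0)
        if key_ is not None:
            key_ = key_.squeeze(0)

    # Keep the in-place contract of the original code: write results back into the inputs.
    query.copy_(query_)
    if key is not None:
        key.copy_(key_)
    return query, key
```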
12 changes: 8 additions & 4 deletions tests/cpu/test_ipex_llm_module.py
@@ -884,23 +884,27 @@ def test_rotary_embedding_tgi(self):
             (1, 32, 128),
             (32, 32, 128),
         ]
-        for size in test_tensor_size:
+        for size, use_key in itertools.product(test_tensor_size, [True, False]):
             q = torch.randn(size).float()
-            k = torch.randn(size).float()
+            k = torch.randn(size).float() if use_key else None
             rotary_dim = size[-1]
             seqlen = size[0]
             position_ids = torch.arange(size[0])
             sin, cos = get_sin_cos(position_ids, rotary_dim, 10000, seqlen, q.dtype)

             ref_q = apply(q, cos, sin)
-            ref_k = apply(k, cos, sin)
+            ref_k = apply(k, cos, sin) if use_key else None

             ipex_q, ipex_k = ipex.llm.functional.rotary_embedding(
                 q, k, sin, cos, rotary_dim, True
             )

             self.assertEqual(ipex_q, ref_q)
-            self.assertEqual(ref_k, ipex_k)
+            if use_key:
+                self.assertEqual(ref_k, ipex_k)
+            else:
+                self.assertIsNone(ipex_k)
+                self.assertIsNone(ref_k)

     def test_add_layernorm(self):
         for add_back in [True, False]:
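The updated test now exercises every tensor size both with and without a key. The same parametrization pattern, shown standalone with placeholder sin/cos instead of the test file's get_sin_cos and apply helpers (identity rotation, so the values are not checked here):

```python
import itertools

import torch
import intel_extension_for_pytorch as ipex

sizes = [(1, 32, 128), (32, 32, 128)]
for size, use_key in itertools.product(sizes, [True, False]):
    q = torch.randn(size).float()
    k = torch.randn(size).float() if use_key else None
    rotary_dim = size[-1]
    num_tokens = size[0]
    # sin=0, cos=1 makes the rotation an identity; shapes follow the documented
    # [num_tokens, rotary_dim] layout.
    sin = torch.zeros(num_tokens, rotary_dim)
    cos = torch.ones(num_tokens, rotary_dim)
    out_q, out_k = ipex.llm.functional.rotary_embedding(q, k, sin, cos, rotary_dim, True)
    assert (out_k is None) == (not use_key)
```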