Commit 10c9c85

export llama with lora
1 parent 753a88e commit 10c9c85

5 files changed: +112 −5 lines changed

Diff for: examples/models/llama/attention.py

+59 −2

@@ -160,6 +160,48 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
+class LoRALinear(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        # Setup weight and bias
+        linear_q = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear_q.weight
+        bias = linear_q.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
@@ -185,9 +227,19 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # self.wq = nn.Linear(
+        #     self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # )
+        self.wq = LoRALinear(
+            in_dim=self.dim,
+            out_dim=self.n_heads * self.head_dim,
+            rank=8,
+            alpha=16.0,
+            dropout=0.0,
+            use_bias=self.attention_qkv_bias,
         )
+
+        # breakpoint()
         self.wk = nn.Linear(
             self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
         )
@@ -238,6 +290,10 @@ def forward(
 
         # QKV
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
+
+        # q_per_kv = self.num_heads // self.num_kv_heads
+        # q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+
         # We need view_copy elimination
         q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -268,6 +324,7 @@ def forward(
 
         mask = self.mask[:seqlen, :seqlen]
 
+        # Somehow, kv become floats.
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

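For context, the LoRALinear added above computes the frozen base projection plus a low-rank correction: out = W x (+ bias) + (alpha / rank) * B(A(dropout(x))), where A is lora_a and B is lora_b. Below is a minimal, self-contained sketch of that forward path; the dimensions in_dim=64, out_dim=128 and the tensor shapes are illustrative placeholders, while rank=8 and alpha=16.0 mirror the values the diff hard-codes into wq.

import torch
import torch.nn as nn

# Illustrative sizes; only rank and alpha are taken from the diff above.
in_dim, out_dim, rank, alpha = 64, 128, 8, 16.0

base = nn.Linear(in_dim, out_dim, bias=False)   # frozen full-rank projection (self.weight)
lora_a = nn.Linear(in_dim, rank, bias=False)    # low-rank down-projection A
lora_b = nn.Linear(rank, out_dim, bias=False)   # low-rank up-projection B

x = torch.randn(2, 16, in_dim)                  # (batch, seq_len, in_dim)
y = base(x) + (alpha / rank) * lora_b(lora_a(x))
print(y.shape)                                  # torch.Size([2, 16, 128])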
Diff for: examples/models/llama/export_llama_lib.py

+34 −3

@@ -26,8 +26,9 @@
 
 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info
-
 from executorch.devtools.etrecord import generate_etrecord
+
+from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
@@ -455,6 +456,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Whether the checkpoint is pre-quantized with QAT or not.",
     )
 
+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Adapter path",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        default=None,
+        help="Adapter config path",
+    )
+
     parser.add_argument(
         "-lora",
         "--use_lora",
@@ -591,6 +604,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
@@ -602,6 +616,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -641,8 +656,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         logging.warning(
             f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
         )
-
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    # breakpoint()
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
@@ -656,10 +671,12 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )
     )
 
+    # breakpoint()
+
     return edge_manager
 
 
 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
         args.pt2e_quantize, args.quantization_mode
     )
@@ -948,6 +965,11 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             args,
         )
     else:
+        from executorch.examples.models.llama.attention import ForwardOptions
+        eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+        fw = ForwardOptions(input_pos=torch.tensor([0], dtype=torch.long))
+        # breakpoint()
+
         builder = _to_edge_and_lower_llama(
             builder_exported,
             modelname,
@@ -958,6 +980,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             args,
         )
 
+    # breakpoint()
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -1020,6 +1043,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1067,6 +1091,7 @@ def _load_llama_model(
             model_class_name,
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
             params=params_path,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1081,6 +1106,11 @@ def _load_llama_model(
             args=args,
         )
     )
+    eg = torch.tensor([[13347]], dtype=torch.long)
+    ip = torch.tensor([0], dtype=torch.long)
+    fw = ForwardOptions(input_pos=ip)
+    # breakpoint()
+    # model.forward(eg, fw)
 
     return LLMEdgeManager(
         model=model,
@@ -1206,6 +1236,7 @@ def _get_source_transforms(  # noqa
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
     if args.use_sdpa_with_kv_cache:
+        # here.
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         transforms.append(replace_sdpa_with_custom_op)

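The export_llama_lib.py changes thread an optional adapter path from the CLI down to model construction: --adapter is canonicalized into adapter_path and forwarded through _load_llama_model into the model factory. A hedged sketch of how the two new flags surface on the parser; the paths are placeholders, and this assumes the parser's remaining arguments all have defaults so they can be omitted:

from executorch.examples.models.llama.export_llama_lib import build_args_parser

# Parse only the new adapter-related flags; everything else keeps its default.
parser = build_args_parser()
args = parser.parse_args(
    ["--adapter", "/tmp/adapter.pt", "--adapter_config", "/tmp/adapter_config.json"]
)
print(args.adapter, args.adapter_config)  # /tmp/adapter.pt /tmp/adapter_config.json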
Diff for: examples/models/llama/model.py

+17
@@ -19,6 +19,8 @@
 
 from executorch.examples.models.llama.model_args import ModelArgs
 
+from torchtune.models import convert_weights
+
 try:
     from .fairseq2 import convert_to_llama_checkpoint
 
@@ -45,6 +47,9 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)
 
+        # Adapter file.
+        adapter_path = kwargs.get("adapter", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -96,6 +101,15 @@ def __init__(self, **kwargs):
         elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
+        # Load adapter.
+        if adapter_path:
+            print("Loading adapter from: ", adapter_path)
+            adapter = torch.load(adapter_path, map_location=device, mmap=True)
+            # Convert the adapter weights from torchtune to Meta format.
+            adapter = convert_weights.tune_to_meta(adapter)
+            # breakpoint()
+            checkpoint.update(adapter)
+
         # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
@@ -174,8 +188,10 @@ def __init__(self, **kwargs):
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+
         # Get checkpoint dtype.
         if checkpoint:
+            # breakpoint()
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
         else:
             self.model_.checkpoint_dtype = torch.float32
@@ -252,6 +268,7 @@ def __init__(self, **kwargs):
         # by default initialized to fp32. This is fine because every other supported type
         # losslessly converts to fp32, so we don't lose precision here.
         if checkpoint:
+            # breakpoint()
             missing, unexpected = self.model_.load_state_dict(
                 checkpoint,
                 strict=False,

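The model.py change reduces to three steps: load the adapter state dict, rename its torchtune-style keys to Meta/llama-style keys with convert_weights.tune_to_meta, and overlay the result onto the base checkpoint before load_state_dict. A standalone sketch of that merge step; the paths are placeholders, and the adapter is assumed to have been saved in torchtune key format:

import torch
from torchtune.models import convert_weights

# Placeholder paths; point these at a real base checkpoint and LoRA adapter.
checkpoint = torch.load("/tmp/consolidated.00.pth", map_location="cpu", mmap=True)
adapter = torch.load("/tmp/adapter.pt", map_location="cpu", mmap=True)

# Rename torchtune-format keys to Meta format, then merge the adapter tensors into
# the base checkpoint dict (new keys are added, overlapping keys are overwritten).
adapter = convert_weights.tune_to_meta(adapter)
checkpoint.update(adapter)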
Diff for: examples/models/llama/source_transformation/quantized_kv_cache.py

+1
@@ -283,6 +283,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
 def _replace_kv_cache_with_custom_kv_cache(module):
     for name, child in module.named_children():
         if isinstance(child, KVCache):
+            # breakpoint()
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
             max_batch_size, n_heads, max_context_length, head_dim = cache_shape

Diff for: extension/llm/export/builder.py

+1
@@ -165,6 +165,7 @@ def source_transform(
             list of source transforms.
         """
         for transform in transforms:
+            # breakpoint()
             self.model = transform(self.model)
         self.applied_source_transforms.extend(transforms)
