
Commit fbb94d9

Fix loading more recent VaeKl checkpoints (#305)
1 parent c4316f6 commit fbb94d9

5 files changed: +116 -78 lines

lib/bumblebee/conversion/pytorch.ex

Lines changed: 62 additions & 43 deletions
@@ -106,35 +106,42 @@ defmodule Bumblebee.Conversion.PyTorch do
       Enum.reduce(layer.parameters, {[], diff}, fn param, {params, diff} ->
         param_expr = params_expr[layer_name][param.name]
 
-        {sources, source_fun} =
+        {sources, builder_fun} =
           case params_source do
-            %{} = layer_params_mapping ->
-              if info = layer_params_mapping[param.name] do
-                info
+            %{} = param_builders ->
+              if param_builder = param_builders[param.name] do
+                param_builder
               else
-                raise "no matching mapping found for parameter #{inspect(param.name)} in #{inspect(layer_params_mapping)}"
+                raise "no matching mapping found for parameter #{inspect(param.name)} in #{inspect(param_builders)}"
               end
 
-            source_layer_name when is_binary(source_layer_name) ->
-              default_layer_param_source(layer, param.name, source_layer_name)
+            source_layer_name
+            when is_binary(source_layer_name) or
+                   is_list(source_layer_name) ->
+              default_layer_param_builder(layer, param.name, source_layer_name)
           end
 
         {all_sources_found?, source_values, source_keys} =
-          for {source_layer_name, source_param_name} <- sources, reduce: {true, [], []} do
+          for source <- sources, reduce: {true, [], []} do
             {all_found?, values, keys} ->
-              source_param_names = List.wrap(source_param_name)
-
-              case lookup_param(pytorch_state, source_layer_name, source_param_names) do
-                {:ok, value, key} -> {all_found?, [value | values], [key | keys]}
-                :error -> {false, values, keys}
+              # Source can be either {layer_name, param_name}, or
+              # a list of these, to find any match
+              source
+              |> List.wrap()
+              |> Enum.find_value(fn {source_layer_name, source_param_name} ->
+                lookup_param(pytorch_state, source_layer_name, source_param_name)
+              end)
+              |> case do
+                {value, key} -> {all_found?, [value | values], [key | keys]}
+                nil -> {false, values, keys}
               end
           end
 
         diff = prepend(diff, :used_keys, source_keys)
 
         {value, diff} =
           if all_sources_found? do
-            value = source_fun.(Enum.reverse(source_values))
+            value = builder_fun.(Enum.reverse(source_values))
 
             case verify_param_shape(param_expr, value) do
               :ok ->
@@ -186,11 +193,15 @@ defmodule Bumblebee.Conversion.PyTorch do
 
     source_templates =
       Enum.flat_map(params_mapping, fn
-        {_target_template, %{} = params_source} ->
-          for {_target_param_name, {sources, _source_fun}} <- params_source,
-              {source_template, _source_param_name} <- sources,
+        {_target_template, %{} = param_builders} ->
+          for {_target_param_name, {sources, _builder_fun}} <- param_builders,
+              ref_or_refs <- sources,
+              {source_template, _source_param_name} <- List.wrap(ref_or_refs),
              do: source_template
 
+        {_target_template, source_templates} when is_list(source_templates) ->
+          source_templates
+
         {_target_template, source_template} when is_binary(source_template) ->
           [source_template]
       end)
@@ -339,17 +350,17 @@ defmodule Bumblebee.Conversion.PyTorch do
 
   defp format_list(items), do: Enum.map_join(items, "\n", &(" * " <> &1))
 
-  defp default_layer_param_source(%{op_name: :dense}, "kernel", layer_name) do
-    {[{layer_name, "weight"}],
+  defp default_layer_param_builder(%{op_name: :dense}, "kernel", layer_name) do
+    {[param_refs(layer_name, "weight")],
      fn [kernel] ->
        [out_features, in_features] = Nx.axes(kernel)
        Nx.transpose(kernel, axes: [in_features, out_features])
      end}
   end
 
-  defp default_layer_param_source(layer, "kernel", layer_name)
+  defp default_layer_param_builder(layer, "kernel", layer_name)
        when layer.op_name in [:conv, :depthwise_conv] do
-    {[{layer_name, "weight"}],
+    {[param_refs(layer_name, "weight")],
      fn [kernel] ->
        [out_channels, in_channels | kernel_spatials] = Nx.axes(kernel)
 
@@ -360,8 +371,8 @@ defmodule Bumblebee.Conversion.PyTorch do
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :conv_transpose} = layer, "kernel", layer_name) do
-    {[{layer_name, "weight"}],
+  defp default_layer_param_builder(%{op_name: :conv_transpose} = layer, "kernel", layer_name) do
+    {[param_refs(layer_name, "weight")],
      fn [kernel] ->
        [in_channels, out_channels | kernel_spatials] = Nx.axes(kernel)
 
@@ -372,57 +383,57 @@ defmodule Bumblebee.Conversion.PyTorch do
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :lstm}, "bias", layer_name) do
-    {[{layer_name, "bias_hh"}, {layer_name, "bias_ih"}],
+  defp default_layer_param_builder(%{op_name: :lstm}, "bias", layer_name) do
+    {[param_refs(layer_name, "bias_hh"), param_refs(layer_name, "bias_ih")],
      fn [bias_hh, bias_ih] ->
        bias = Nx.add(bias_ih, bias_hh)
        bias = Nx.reshape(bias, {4, :auto})
        {bias[0], bias[1], bias[2], bias[3]}
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :lstm}, "input_kernel", layer_name) do
-    {[{layer_name, "weight_ih"}],
+  defp default_layer_param_builder(%{op_name: :lstm}, "input_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_ih")],
      fn [weight_ih] ->
        weight_ih = weight_ih |> unflatten_leading(4) |> Nx.transpose(axes: [0, 2, 1])
        {weight_ih[0], weight_ih[1], weight_ih[2], weight_ih[3]}
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :lstm}, "hidden_kernel", layer_name) do
-    {[{layer_name, "weight_hh"}],
+  defp default_layer_param_builder(%{op_name: :lstm}, "hidden_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_hh")],
      fn [weight_hh] ->
        weight_hh = weight_hh |> unflatten_leading(4) |> Nx.transpose(axes: [0, 2, 1])
        {weight_hh[0], weight_hh[1], weight_hh[2], weight_hh[3]}
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :gru}, "bias", layer_name) do
-    {[{layer_name, "bias_hh"}, {layer_name, "bias_ih"}],
+  defp default_layer_param_builder(%{op_name: :gru}, "bias", layer_name) do
+    {[param_refs(layer_name, "bias_hh"), param_refs(layer_name, "bias_ih")],
      fn [bias_hh, bias_ih] ->
        bias_hh = unflatten_leading(bias_hh, 3)
        bias_ih = unflatten_leading(bias_ih, 3)
        {Nx.add(bias_ih[0], bias_hh[0]), Nx.add(bias_ih[1], bias_hh[1]), bias_ih[2], bias_hh[2]}
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :gru}, "input_kernel", layer_name) do
-    {[{layer_name, "weight_ih"}],
+  defp default_layer_param_builder(%{op_name: :gru}, "input_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_ih")],
      fn [weight_ih] ->
        weight_ih = weight_ih |> unflatten_leading(4) |> Nx.transpose(axes: [0, 2, 1])
        {weight_ih[0], weight_ih[1], weight_ih[2]}
      end}
   end
 
-  defp default_layer_param_source(%{op_name: :gru}, "hidden_kernel", layer_name) do
-    {[{layer_name, "weight_hh"}],
+  defp default_layer_param_builder(%{op_name: :gru}, "hidden_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_hh")],
      fn [weight_hh] ->
        weight_hh = weight_hh |> unflatten_leading(3) |> Nx.transpose(axes: [0, 2, 1])
        {weight_hh[0], weight_hh[1], weight_hh[2]}
      end}
   end
 
-  defp default_layer_param_source(_layer, param_name, layer_name) do
+  defp default_layer_param_builder(_layer, param_name, layer_name) do
     pytorch_names =
       case param_name do
         # PyTorch uses "weight" instead of "kernel" everywhere
@@ -440,18 +451,26 @@ defmodule Bumblebee.Conversion.PyTorch do
         name -> [name]
       end
 
-    {[{layer_name, pytorch_names}], fn [value] -> value end}
+    param_source = Enum.flat_map(pytorch_names, &param_refs(layer_name, &1))
+
+    {[param_source], fn [value] -> value end}
+  end
+
+  defp param_refs(layer_name, param_name) do
+    for layer_name <- List.wrap(layer_name) do
+      {layer_name, param_name}
+    end
   end
 
-  defp lookup_param(pytorch_state, layer_name, pytorch_names) do
+  defp lookup_param(pytorch_state, layer_name, pytorch_name) do
     # Note: the PyTorch model may have some root-level parameters that
-    # we need to namespace under a layer in Axon, so after trying params
-    # within layer_name, we also try the parameter name directly
-    pytorch_keys = Enum.map(pytorch_names, &(layer_name <> "." <> &1)) ++ pytorch_names
+    # we need to namespace under a layer in Axon, so after trying the
+    # param within layer_name, we also try the param name directly
+    pytorch_keys = [layer_name <> "." <> pytorch_name, pytorch_name]
 
-    Enum.find_value(pytorch_keys, :error, fn pytorch_key ->
+    Enum.find_value(pytorch_keys, fn pytorch_key ->
       if value = pytorch_state[pytorch_key] do
-        {:ok, value, pytorch_key}
+        {value, pytorch_key}
       end
     end)
   end
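
Note: with this change each parameter source is a list of {layer_name, param_name} refs, and any entry may itself be a list of alternatives that are tried in order. Below is a minimal sketch of that fallback lookup, using hypothetical checkpoint keys (the real lookup_param/3 additionally falls back to the bare parameter name for root-level parameters):

    # hypothetical PyTorch state map, values stand in for tensors
    pytorch_state = %{"mid_block.attentions.0.to_q.weight" => :q_tensor}

    # two candidate refs for the same target parameter, newest name first
    source = [
      {"mid_block.attentions.0.to_q", "weight"},
      {"mid_block.attentions.0.query", "weight"}
    ]

    source
    |> List.wrap()
    |> Enum.find_value(fn {layer_name, param_name} ->
      key = layer_name <> "." <> param_name
      if value = pytorch_state[key], do: {value, key}
    end)
    #=> {:q_tensor, "mid_block.attentions.0.to_q.weight"}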

lib/bumblebee/diffusion/unet_2d_conditional.ex

Lines changed: 5 additions & 4 deletions
@@ -407,6 +407,8 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do
   end
 
   defimpl Bumblebee.HuggingFace.Transformers.Model do
+    alias Bumblebee.HuggingFace.Transformers
+
     def params_mapping(_spec) do
       block_mapping = %{
         "transformers.{m}.norm" => "attentions.{m}.norm",
@@ -449,10 +451,9 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do
       }
 
       blocks_mapping =
-        for {target, source} <- block_mapping,
-            prefix <- ["down_blocks.{n}", "mid_block", "up_blocks.{n}"],
-            do: {prefix <> "." <> target, prefix <> "." <> source},
-            into: %{}
+        ["down_blocks.{n}", "mid_block", "up_blocks.{n}"]
+        |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1))
+        |> Enum.reduce(&Map.merge/2)
 
       %{
         "time_embedding.intermediate" => "time_embedding.linear_1",

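Note: for plain string sources, Transformers.Utils.prefix_params_mapping(block_mapping, prefix, prefix) amounts to prefixing both sides of every entry and merging the per-prefix maps. The sketch below shows that rough equivalent of the replaced comprehension; it is not the helper's actual implementation, which presumably also rewrites list and builder sources via map_params_source_layer_names/2.

    # trimmed-down block mapping, one entry for illustration
    block_mapping = %{"transformers.{m}.norm" => "attentions.{m}.norm"}

    ["down_blocks.{n}", "mid_block", "up_blocks.{n}"]
    |> Enum.map(fn prefix ->
      Map.new(block_mapping, fn {target, source} ->
        {prefix <> "." <> target, prefix <> "." <> source}
      end)
    end)
    |> Enum.reduce(&Map.merge/2)
    #=> %{
    #=>   "down_blocks.{n}.transformers.{m}.norm" => "down_blocks.{n}.attentions.{m}.norm",
    #=>   "mid_block.transformers.{m}.norm" => "mid_block.attentions.{m}.norm",
    #=>   "up_blocks.{n}.transformers.{m}.norm" => "up_blocks.{n}.attentions.{m}.norm"
    #=> }
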
lib/bumblebee/diffusion/vae_kl.ex

Lines changed: 15 additions & 13 deletions
@@ -443,13 +443,16 @@ defmodule Bumblebee.Diffusion.VaeKl do
   end
 
   defimpl Bumblebee.HuggingFace.Transformers.Model do
+    alias Bumblebee.HuggingFace.Transformers
+
     def params_mapping(_spec) do
       block_mapping = %{
         "attentions.{m}.norm" => "attentions.{m}.group_norm",
-        "attentions.{m}.query" => "attentions.{m}.query",
-        "attentions.{m}.key" => "attentions.{m}.key",
-        "attentions.{m}.value" => "attentions.{m}.value",
-        "attentions.{m}.output" => "attentions.{m}.proj_attn",
+        # The layer name has been renamed upstream, so we try both
+        "attentions.{m}.query" => ["attentions.{m}.to_q", "attentions.{m}.query"],
+        "attentions.{m}.key" => ["attentions.{m}.to_k", "attentions.{m}.key"],
+        "attentions.{m}.value" => ["attentions.{m}.to_v", "attentions.{m}.value"],
+        "attentions.{m}.output" => ["attentions.{m}.to_out.0", "attentions.{m}.proj_attn"],
         "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1",
         "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1",
         "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2",
@@ -460,15 +463,14 @@ defmodule Bumblebee.Diffusion.VaeKl do
       }
 
       blocks_mapping =
-        for {target, source} <- block_mapping,
-            prefix <- [
-              "encoder.down_blocks.{n}",
-              "encoder.mid_block",
-              "decoder.mid_block",
-              "decoder.up_blocks.{n}"
-            ],
-            do: {prefix <> "." <> target, prefix <> "." <> source},
-            into: %{}
+        [
+          "encoder.down_blocks.{n}",
+          "encoder.mid_block",
+          "decoder.mid_block",
+          "decoder.up_blocks.{n}"
+        ]
+        |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1))
+        |> Enum.reduce(&Map.merge/2)
 
       %{
         "encoder.input_conv" => "encoder.conv_in",

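Note: this is the core of the fix for #305. A target layer can now map to a list of candidate source layer names; through default_layer_param_builder/3 and param_refs/2 each candidate expands into its own {layer_name, param_name} ref, and whichever key exists in the checkpoint is used. A small sketch of that expansion, with hypothetical names:

    # stand-in for the private param_refs/2 added in pytorch.ex above
    param_refs = fn layer_names, param_name ->
      for layer_name <- List.wrap(layer_names), do: {layer_name, param_name}
    end

    param_refs.(["mid_block.attentions.0.to_q", "mid_block.attentions.0.query"], "weight")
    #=> [
    #=>   {"mid_block.attentions.0.to_q", "weight"},
    #=>   {"mid_block.attentions.0.query", "weight"}
    #=> ]
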
lib/bumblebee/huggingface/transformers/model.ex

Lines changed: 15 additions & 13 deletions
@@ -4,16 +4,17 @@ defprotocol Bumblebee.HuggingFace.Transformers.Model do
   # This protocol defines details related to loading Bumblebee model
   # from huggingface/transformers model.
 
-  @type params_mapping :: %{
-          layer_name() => layer_name() | params_source()
-        }
+  @type params_mapping :: %{layer_name() => params_source()}
 
-  @type params_source :: %{
-          param_name() =>
-            {list(source()), (list(Nx.tensor()) -> Nx.Tensor.t() | Nx.Container.t())}
-        }
+  @type params_source :: layer_name() | list(layer_name()) | param_builders()
 
-  @type source :: {layer_name(), param_name() | list(param_name())}
+  @type param_builders :: %{param_name() => param_builder()}
+
+  @type param_builder ::
+          {list(param_source()), (list(Nx.tensor()) -> Nx.Tensor.t() | Nx.Container.t())}
+
+  @type param_source :: param_ref() | list(param_ref())
+  @type param_ref :: {layer_name(), param_name()}
 
   @type layer_name :: String.t()
   @type param_name :: String.t()
@@ -53,8 +54,8 @@ defprotocol Bumblebee.HuggingFace.Transformers.Model do
   automatically.
 
   In some cases, particularly with model-specific layers/parameters,
-  we may need more control over the parameter mapping. In such cases, instead
-  of source layer name, a map with parameter-level transformations
+  we may need more control over the parameter mapping. In such cases,
+  instead of source layer name, a map with parameter-level transformations
   may be specified:
 
       %{
@@ -69,9 +70,10 @@ defprotocol Bumblebee.HuggingFace.Transformers.Model do
 
   For each parameter, we specify a list of source parameters in the
   form of `{source_layer_name, source_param_name}`, then a function
-  to build our parameter value. Multiple source parameter names to
-  try may be specified. With the explicit transformation we can
-  handle arbitrary parameter name and value transformations.
+  to build our parameter value. Instead of a single tuple, we can
+  specify a list of those to try one by one. With the explicit
+  transformation we can handle arbitrary parameter name and value
+  transformations.
   """
   @spec params_mapping(t()) :: params_mapping()
   def params_mapping(spec)
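
Note: an illustrative params_mapping/1 result covering the three shapes that params_source now allows; the layer names below are made up:

    %{
      # single source layer name
      "encoder.output_norm" => "encoder.conv_norm_out",
      # alternative source layer names, tried in order
      "decoder.attention.query" => ["decoder.attn.to_q", "decoder.attn.query"],
      # parameter-level builders
      "embedder.token_embedding" => %{
        "kernel" => {[{"shared", "weight"}], fn [kernel] -> kernel end}
      }
    }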

lib/bumblebee/huggingface/transformers/utils.ex

Lines changed: 19 additions & 5 deletions
@@ -28,14 +28,28 @@ defmodule Bumblebee.HuggingFace.Transformers.Utils do
   @spec map_params_source_layer_names(
           Transformers.Model.params_source(),
           (String.t() -> String.t())
-        ) :: Transformers.Model.layer_name() | Transformers.Model.params_source()
-  def map_params_source_layer_names(%{} = params_source, fun) do
-    Map.new(params_source, fn {param_name, {sources, source_fun}} ->
-      sources = for {layer_name, param_name} <- sources, do: {fun.(layer_name), param_name}
-      {param_name, {sources, source_fun}}
+        ) :: Transformers.Model.params_source()
+  def map_params_source_layer_names(%{} = param_builders, fun) do
+    Map.new(param_builders, fn {param_name, {sources, builder_fun}} ->
+      sources =
+        for ref_or_refs <- sources do
+          case ref_or_refs do
+            {layer_name, param_name} ->
+              {fun.(layer_name), param_name}
+
+            refs ->
+              for {layer_name, param_name} <- refs, do: {fun.(layer_name), param_name}
+          end
+        end
+
+      {param_name, {sources, builder_fun}}
     end)
   end
 
+  def map_params_source_layer_names(layer_names, fun) when is_list(layer_names) do
+    Enum.map(layer_names, fun)
+  end
+
   def map_params_source_layer_names(layer_name, fun) when is_binary(layer_name) do
     fun.(layer_name)
   end
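
Note: a quick usage sketch of the new list clause, with hypothetical names; each candidate layer name is prefixed and the list shape is preserved:

    alias Bumblebee.HuggingFace.Transformers

    Transformers.Utils.map_params_source_layer_names(
      ["attentions.0.to_q", "attentions.0.query"],
      &("decoder.mid_block." <> &1)
    )
    #=> ["decoder.mid_block.attentions.0.to_q", "decoder.mid_block.attentions.0.query"]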
