@@ -106,35 +106,42 @@ defmodule Bumblebee.Conversion.PyTorch do
         Enum.reduce(layer.parameters, {[], diff}, fn param, {params, diff} ->
           param_expr = params_expr[layer_name][param.name]

-          {sources, source_fun} =
+          {sources, builder_fun} =
             case params_source do
-              %{} = layer_params_mapping ->
-                if info = layer_params_mapping[param.name] do
-                  info
+              %{} = param_builders ->
+                if param_builder = param_builders[param.name] do
+                  param_builder
                 else
-                  raise "no matching mapping found for parameter #{inspect(param.name)} in #{inspect(layer_params_mapping)}"
+                  raise "no matching mapping found for parameter #{inspect(param.name)} in #{inspect(param_builders)}"
                 end

-              source_layer_name when is_binary(source_layer_name) ->
-                default_layer_param_source(layer, param.name, source_layer_name)
+              source_layer_name
+              when is_binary(source_layer_name) or
+                     is_list(source_layer_name) ->
+                default_layer_param_builder(layer, param.name, source_layer_name)
             end

           {all_sources_found?, source_values, source_keys} =
-            for {source_layer_name, source_param_name} <- sources, reduce: {true, [], []} do
+            for source <- sources, reduce: {true, [], []} do
               {all_found?, values, keys} ->
-                source_param_names = List.wrap(source_param_name)
-
-                case lookup_param(pytorch_state, source_layer_name, source_param_names) do
-                  {:ok, value, key} -> {all_found?, [value | values], [key | keys]}
-                  :error -> {false, values, keys}
+                # Source can be either {layer_name, param_name}, or
+                # a list of these, to find any match
+                source
+                |> List.wrap()
+                |> Enum.find_value(fn {source_layer_name, source_param_name} ->
+                  lookup_param(pytorch_state, source_layer_name, source_param_name)
+                end)
+                |> case do
+                  {value, key} -> {all_found?, [value | values], [key | keys]}
+                  nil -> {false, values, keys}
                 end
             end

           diff = prepend(diff, :used_keys, source_keys)

           {value, diff} =
             if all_sources_found? do
-              value = source_fun.(Enum.reverse(source_values))
+              value = builder_fun.(Enum.reverse(source_values))

               case verify_param_shape(param_expr, value) do
                 :ok ->
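For context, here is a minimal sketch of a param builder entry using the new source shape; the layer and parameter names are hypothetical and not part of this commit. Each source is either a single {layer_name, param_name} ref or a list of alternative refs, and the lookup above uses the first ref present in the PyTorch state:

    # Hypothetical param builder for a dense "kernel" parameter whose
    # weights may live under either of two PyTorch layer names.
    %{
      "kernel" =>
        {[
           [{"encoder.output", "weight"}, {"encoder.out_proj", "weight"}]
         ],
         # builder_fun receives the found values in source order
         fn [weight] -> Nx.transpose(weight) end}
    }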
@@ -186,11 +193,15 @@ defmodule Bumblebee.Conversion.PyTorch do

     source_templates =
       Enum.flat_map(params_mapping, fn
-        {_target_template, %{} = params_source} ->
-          for {_target_param_name, {sources, _source_fun}} <- params_source,
-              {source_template, _source_param_name} <- sources,
+        {_target_template, %{} = param_builders} ->
+          for {_target_param_name, {sources, _builder_fun}} <- param_builders,
+              ref_or_refs <- sources,
+              {source_template, _source_param_name} <- List.wrap(ref_or_refs),
              do: source_template

+        {_target_template, source_templates} when is_list(source_templates) ->
+          source_templates
+
         {_target_template, source_template} when is_binary(source_template) ->
           [source_template]
       end)
@@ -339,17 +350,17 @@ defmodule Bumblebee.Conversion.PyTorch do

   defp format_list(items), do: Enum.map_join(items, "\n", &("  * " <> &1))

-  defp default_layer_param_source(%{op_name: :dense}, "kernel", layer_name) do
-    {[{layer_name, "weight"}],
+  defp default_layer_param_builder(%{op_name: :dense}, "kernel", layer_name) do
+    {[param_refs(layer_name, "weight")],
      fn [kernel] ->
        [out_features, in_features] = Nx.axes(kernel)
        Nx.transpose(kernel, axes: [in_features, out_features])
      end}
   end

-  defp default_layer_param_source(layer, "kernel", layer_name)
+  defp default_layer_param_builder(layer, "kernel", layer_name)
        when layer.op_name in [:conv, :depthwise_conv] do
-    {[{layer_name, "weight"}],
+    {[param_refs(layer_name, "weight")],
      fn [kernel] ->
        [out_channels, in_channels | kernel_spatials] = Nx.axes(kernel)

@@ -360,8 +371,8 @@ defmodule Bumblebee.Conversion.PyTorch do
      end}
   end

-  defp default_layer_param_source(%{op_name: :conv_transpose} = layer, "kernel", layer_name) do
-    {[{layer_name, "weight"}],
+  defp default_layer_param_builder(%{op_name: :conv_transpose} = layer, "kernel", layer_name) do
+    {[param_refs(layer_name, "weight")],
      fn [kernel] ->
        [in_channels, out_channels | kernel_spatials] = Nx.axes(kernel)

@@ -372,57 +383,57 @@ defmodule Bumblebee.Conversion.PyTorch do
      end}
   end

-  defp default_layer_param_source(%{op_name: :lstm}, "bias", layer_name) do
-    {[{layer_name, "bias_hh"}, {layer_name, "bias_ih"}],
+  defp default_layer_param_builder(%{op_name: :lstm}, "bias", layer_name) do
+    {[param_refs(layer_name, "bias_hh"), param_refs(layer_name, "bias_ih")],
      fn [bias_hh, bias_ih] ->
        bias = Nx.add(bias_ih, bias_hh)
        bias = Nx.reshape(bias, {4, :auto})
        {bias[0], bias[1], bias[2], bias[3]}
      end}
   end

-  defp default_layer_param_source(%{op_name: :lstm}, "input_kernel", layer_name) do
-    {[{layer_name, "weight_ih"}],
+  defp default_layer_param_builder(%{op_name: :lstm}, "input_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_ih")],
      fn [weight_ih] ->
        weight_ih = weight_ih |> unflatten_leading(4) |> Nx.transpose(axes: [0, 2, 1])
        {weight_ih[0], weight_ih[1], weight_ih[2], weight_ih[3]}
      end}
   end

-  defp default_layer_param_source(%{op_name: :lstm}, "hidden_kernel", layer_name) do
-    {[{layer_name, "weight_hh"}],
+  defp default_layer_param_builder(%{op_name: :lstm}, "hidden_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_hh")],
      fn [weight_hh] ->
        weight_hh = weight_hh |> unflatten_leading(4) |> Nx.transpose(axes: [0, 2, 1])
        {weight_hh[0], weight_hh[1], weight_hh[2], weight_hh[3]}
      end}
   end

-  defp default_layer_param_source(%{op_name: :gru}, "bias", layer_name) do
-    {[{layer_name, "bias_hh"}, {layer_name, "bias_ih"}],
+  defp default_layer_param_builder(%{op_name: :gru}, "bias", layer_name) do
+    {[param_refs(layer_name, "bias_hh"), param_refs(layer_name, "bias_ih")],
      fn [bias_hh, bias_ih] ->
        bias_hh = unflatten_leading(bias_hh, 3)
        bias_ih = unflatten_leading(bias_ih, 3)
        {Nx.add(bias_ih[0], bias_hh[0]), Nx.add(bias_ih[1], bias_hh[1]), bias_ih[2], bias_hh[2]}
      end}
   end

-  defp default_layer_param_source(%{op_name: :gru}, "input_kernel", layer_name) do
-    {[{layer_name, "weight_ih"}],
+  defp default_layer_param_builder(%{op_name: :gru}, "input_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_ih")],
      fn [weight_ih] ->
        weight_ih = weight_ih |> unflatten_leading(4) |> Nx.transpose(axes: [0, 2, 1])
        {weight_ih[0], weight_ih[1], weight_ih[2]}
      end}
   end

-  defp default_layer_param_source(%{op_name: :gru}, "hidden_kernel", layer_name) do
-    {[{layer_name, "weight_hh"}],
+  defp default_layer_param_builder(%{op_name: :gru}, "hidden_kernel", layer_name) do
+    {[param_refs(layer_name, "weight_hh")],
      fn [weight_hh] ->
        weight_hh = weight_hh |> unflatten_leading(3) |> Nx.transpose(axes: [0, 2, 1])
        {weight_hh[0], weight_hh[1], weight_hh[2]}
      end}
   end

-  defp default_layer_param_source(_layer, param_name, layer_name) do
+  defp default_layer_param_builder(_layer, param_name, layer_name) do
     pytorch_names =
       case param_name do
         # PyTorch uses "weight" instead of "kernel" everywhere
@@ -440,18 +451,26 @@ defmodule Bumblebee.Conversion.PyTorch do
         name -> [name]
       end

-    {[{layer_name, pytorch_names}], fn [value] -> value end}
+    param_source = Enum.flat_map(pytorch_names, &param_refs(layer_name, &1))
+
+    {[param_source], fn [value] -> value end}
+  end
+
+  defp param_refs(layer_name, param_name) do
+    for layer_name <- List.wrap(layer_name) do
+      {layer_name, param_name}
+    end
   end

-  defp lookup_param(pytorch_state, layer_name, pytorch_names) do
+  defp lookup_param(pytorch_state, layer_name, pytorch_name) do
     # Note: the PyTorch model may have some root-level parameters that
-    # we need to namespace under a layer in Axon, so after trying params
-    # within layer_name, we also try the parameter name directly
-    pytorch_keys = Enum.map(pytorch_names, &(layer_name <> "." <> &1)) ++ pytorch_names
+    # we need to namespace under a layer in Axon, so after trying the
+    # param within layer_name, we also try the param name directly
+    pytorch_keys = [layer_name <> "." <> pytorch_name, pytorch_name]

-    Enum.find_value(pytorch_keys, :error, fn pytorch_key ->
+    Enum.find_value(pytorch_keys, fn pytorch_key ->
       if value = pytorch_state[pytorch_key] do
-        {:ok, value, pytorch_key}
+        {value, pytorch_key}
       end
     end)
   end
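As a usage sketch (with hypothetical layer names), the param_refs/2 helper added above expands a single layer name, or a list of alternative layer names, into the list of {layer_name, param_name} refs that the lookup loop tries in order:

    param_refs("decoder.lstm", "weight_ih")
    #=> [{"decoder.lstm", "weight_ih"}]

    param_refs(["decoder.lstm", "rnn"], "weight_ih")
    #=> [{"decoder.lstm", "weight_ih"}, {"rnn", "weight_ih"}]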