@@ -209,20 +209,21 @@ macro model(expr, warn=false)
209
209
end
210
210
211
211
function model (mod, linenumbernode, expr, warn)
212
- modelinfo = build_model_info (expr)
212
+ modeldef = build_model_definition (expr)
213
213
214
214
# Generate main body
215
- modelinfo [:body ] = generate_mainbody (mod, modelinfo[ : modeldef] [:body ], warn)
215
+ modeldef [:body ] = generate_mainbody (mod, modeldef[:body ], warn)
216
216
217
- return build_output (modelinfo , linenumbernode)
217
+ return build_output (modeldef , linenumbernode)
218
218
end
219
219
220
220
"""
221
- build_model_info (input_expr)
221
+ build_model_definition (input_expr)
222
222
223
- Builds the `model_info` dictionary from the model's expression.
223
+ Builds the `modeldef` dictionary from the model's expression, where
224
+ `modeldef` is a dictionary compatible with `MacroTools.combinedef`.
224
225
"""
225
- function build_model_info (input_expr)
226
+ function build_model_definition (input_expr)
226
227
# Break up the model definition and extract its name, arguments, and function body
227
228
modeldef = MacroTools. splitdef (input_expr)
228
229
@@ -238,66 +239,13 @@ function build_model_info(input_expr)
238
239
239
240
# Shortcut if the model does not have any arguments
240
241
if ! haskey (modeldef, :args ) && ! haskey (modeldef, :kwargs )
241
- modelinfo = Dict (
242
- :allargs_exprs => [],
243
- :allargs_syms => [],
244
- :allargs_namedtuple => NamedTuple (),
245
- :defaults_namedtuple => NamedTuple (),
246
- :modeldef => modeldef,
247
- )
248
- return modelinfo
242
+ return modeldef
249
243
end
250
244
251
245
# Ensure that all arguments have a name, i.e., are of the form `name` or `name::T`
252
246
addargnames! (modeldef[:args ])
253
247
254
- # Extract the positional and keyword arguments from the model definition.
255
- allargs = vcat (modeldef[:args ], modeldef[:kwargs ])
256
-
257
- # Split the argument expressions and the default values.
258
- allargs_exprs_defaults = map (allargs) do arg
259
- MacroTools. @match arg begin
260
- (x_ = val_) => (x, val)
261
- x_ => (x, NO_DEFAULT)
262
- end
263
- end
264
-
265
- # Extract the expressions of the arguments, without default values.
266
- allargs_exprs = first .(allargs_exprs_defaults)
267
-
268
- # Extract the names of the arguments.
269
- allargs_syms = map (allargs_exprs) do arg
270
- MacroTools. @match arg begin
271
- (name_:: _ ) => name
272
- x_ => x
273
- end
274
- end
275
-
276
- # Build named tuple expression of the argument symbols and variables of the same name.
277
- allargs_namedtuple = to_namedtuple_expr (allargs_syms)
278
-
279
- # Extract default values of the positional and keyword arguments.
280
- default_syms = []
281
- default_vals = []
282
- for (sym, (expr, val)) in zip (allargs_syms, allargs_exprs_defaults)
283
- if val != = NO_DEFAULT
284
- push! (default_syms, sym)
285
- push! (default_vals, val)
286
- end
287
- end
288
-
289
- # Build named tuple expression of the argument symbols with default values.
290
- defaults_namedtuple = to_namedtuple_expr (default_syms)
291
-
292
- modelinfo = Dict (
293
- :allargs_exprs => allargs_exprs,
294
- :allargs_syms => allargs_syms,
295
- :allargs_namedtuple => allargs_namedtuple,
296
- :defaults_namedtuple => defaults_namedtuple,
297
- :modeldef => modeldef,
298
- )
299
-
300
- return modelinfo
248
+ return modeldef
301
249
end
302
250
303
251
"""
@@ -561,14 +509,32 @@ hasmissing(::Type{>:Missing}) = true
561
509
hasmissing (:: Type{<:AbstractArray{TA}} ) where {TA} = hasmissing (TA)
562
510
hasmissing (:: Type{Union{}} ) = false # issue #368
563
511
512
+ function splitarg_to_expr ((arg_name, arg_type, is_splat, default))
513
+ return is_splat ? :($ arg_name... ) : arg_name
514
+ end
515
+
516
+ function namedtuple_from_splitargs (splitargs)
517
+ names = map (splitargs) do (arg_name, arg_type, is_splat, default)
518
+ is_splat ? Symbol (" #splat#$(arg_name) " ) : arg_name
519
+ end
520
+ names_expr = Expr (:tuple , map (QuoteNode, names)... )
521
+ vals = Expr (:tuple , map (first, splitargs)... )
522
+ return :(NamedTuple {$names_expr} ($ vals))
523
+ end
524
+
525
+ is_splat_symbol (s:: Symbol ) = startswith (string (s), " #splat#" )
526
+
564
527
"""
565
- build_output(modelinfo , linenumbernode)
528
+ build_output(modeldef , linenumbernode)
566
529
567
530
Builds the output expression.
568
531
"""
569
- function build_output (modelinfo, linenumbernode)
532
+ function build_output (modeldef, linenumbernode)
533
+ args = modeldef[:args ]
534
+ kwargs = modeldef[:kwargs ]
535
+
570
536
# # Build the anonymous evaluator from the user-provided model definition.
571
- evaluatordef = deepcopy (modelinfo[ : modeldef] )
537
+ evaluatordef = deepcopy (modeldef)
572
538
573
539
# Add the internal arguments to the user-specified arguments (positional + keywords).
574
540
evaluatordef[:args ] = vcat (
@@ -577,12 +543,9 @@ function build_output(modelinfo, linenumbernode)
577
543
:(__varinfo__:: $ (DynamicPPL. AbstractVarInfo)),
578
544
:(__context__:: $ (DynamicPPL. AbstractContext)),
579
545
],
580
- modelinfo[ :allargs_exprs ] ,
546
+ args ,
581
547
)
582
548
583
- # Delete the keyword arguments.
584
- evaluatordef[:kwargs ] = []
585
-
586
549
# Replace the user-provided function body with the version created by DynamicPPL.
587
550
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
588
551
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
@@ -593,18 +556,13 @@ function build_output(modelinfo, linenumbernode)
593
556
# See the docstrings of `replace_returns` for more info.
594
557
evaluatordef[:body ] = MacroTools. @q begin
595
558
$ (linenumbernode)
596
- $ (replace_returns (make_returns_explicit! (modelinfo [:body ])))
559
+ $ (replace_returns (make_returns_explicit! (modeldef [:body ])))
597
560
end
598
561
599
562
# # Build the model function.
600
563
601
- # Extract the named tuple expression of all arguments and the default values.
602
- allargs_namedtuple = modelinfo[:allargs_namedtuple ]
603
- defaults_namedtuple = modelinfo[:defaults_namedtuple ]
604
-
605
564
# Obtain or generate the name of the model to support functors:
606
565
# https://github.com/TuringLang/DynamicPPL.jl/issues/367
607
- modeldef = modelinfo[:modeldef ]
608
566
if MacroTools. @capture (modeldef[:name ], :: T_ )
609
567
name = gensym (:f )
610
568
modeldef[:name ] = Expr (:(:: ), name, T)
@@ -613,13 +571,18 @@ function build_output(modelinfo, linenumbernode)
613
571
throw (ArgumentError (" unsupported format of model function" ))
614
572
end
615
573
574
+ args_split = map (MacroTools. splitarg, args)
575
+ kwargs_split = map (MacroTools. splitarg, kwargs)
576
+ args_nt = namedtuple_from_splitargs (args_split)
577
+ kwargs_inclusion = map (splitarg_to_expr, kwargs_split)
578
+
616
579
# Update the function body of the user-specified model.
617
580
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
618
581
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
619
582
# to the call site
620
583
modeldef[:body ] = MacroTools. @q begin
621
584
$ (linenumbernode)
622
- return $ (DynamicPPL. Model)($ name, $ allargs_namedtuple, $ defaults_namedtuple )
585
+ return $ (DynamicPPL. Model)($ name, $ args_nt; $ (kwargs_inclusion ... ) )
623
586
end
624
587
625
588
return MacroTools. @q begin
0 commit comments