Skip to content

Commit a4a78aa

Browse files
torfjeldegithub-actions[bot]devmotion
authored
Proper support for kwargs + splatting of args and kwargs (#477)
* initial work on simplification of compiler + supporting splatting of args and kwargs * fixed prob_macro * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed doctests * dont use splat due to backwards compat * Update src/compiler.jl Co-authored-by: David Widmann <[email protected]> * added a make_evaluate_args_and_kwargs method * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renamed build_model_info * bumped minor mode because Turing.jl breaks * improved docstring --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
1 parent 58cdb12 commit a4a78aa

File tree

6 files changed

+108
-90
lines changed

6 files changed

+108
-90
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.22.4"
3+
version = "0.23.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

+38-75
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,21 @@ macro model(expr, warn=false)
209209
end
210210

211211
function model(mod, linenumbernode, expr, warn)
212-
modelinfo = build_model_info(expr)
212+
modeldef = build_model_definition(expr)
213213

214214
# Generate main body
215-
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
215+
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
216216

217-
return build_output(modelinfo, linenumbernode)
217+
return build_output(modeldef, linenumbernode)
218218
end
219219

220220
"""
221-
build_model_info(input_expr)
221+
build_model_definition(input_expr)
222222
223-
Builds the `model_info` dictionary from the model's expression.
223+
Builds the `modeldef` dictionary from the model's expression, where
224+
`modeldef` is a dictionary compatible with `MacroTools.combinedef`.
224225
"""
225-
function build_model_info(input_expr)
226+
function build_model_definition(input_expr)
226227
# Break up the model definition and extract its name, arguments, and function body
227228
modeldef = MacroTools.splitdef(input_expr)
228229

@@ -238,66 +239,13 @@ function build_model_info(input_expr)
238239

239240
# Shortcut if the model does not have any arguments
240241
if !haskey(modeldef, :args) && !haskey(modeldef, :kwargs)
241-
modelinfo = Dict(
242-
:allargs_exprs => [],
243-
:allargs_syms => [],
244-
:allargs_namedtuple => NamedTuple(),
245-
:defaults_namedtuple => NamedTuple(),
246-
:modeldef => modeldef,
247-
)
248-
return modelinfo
242+
return modeldef
249243
end
250244

251245
# Ensure that all arguments have a name, i.e., are of the form `name` or `name::T`
252246
addargnames!(modeldef[:args])
253247

254-
# Extract the positional and keyword arguments from the model definition.
255-
allargs = vcat(modeldef[:args], modeldef[:kwargs])
256-
257-
# Split the argument expressions and the default values.
258-
allargs_exprs_defaults = map(allargs) do arg
259-
MacroTools.@match arg begin
260-
(x_ = val_) => (x, val)
261-
x_ => (x, NO_DEFAULT)
262-
end
263-
end
264-
265-
# Extract the expressions of the arguments, without default values.
266-
allargs_exprs = first.(allargs_exprs_defaults)
267-
268-
# Extract the names of the arguments.
269-
allargs_syms = map(allargs_exprs) do arg
270-
MacroTools.@match arg begin
271-
(name_::_) => name
272-
x_ => x
273-
end
274-
end
275-
276-
# Build named tuple expression of the argument symbols and variables of the same name.
277-
allargs_namedtuple = to_namedtuple_expr(allargs_syms)
278-
279-
# Extract default values of the positional and keyword arguments.
280-
default_syms = []
281-
default_vals = []
282-
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
283-
if val !== NO_DEFAULT
284-
push!(default_syms, sym)
285-
push!(default_vals, val)
286-
end
287-
end
288-
289-
# Build named tuple expression of the argument symbols with default values.
290-
defaults_namedtuple = to_namedtuple_expr(default_syms)
291-
292-
modelinfo = Dict(
293-
:allargs_exprs => allargs_exprs,
294-
:allargs_syms => allargs_syms,
295-
:allargs_namedtuple => allargs_namedtuple,
296-
:defaults_namedtuple => defaults_namedtuple,
297-
:modeldef => modeldef,
298-
)
299-
300-
return modelinfo
248+
return modeldef
301249
end
302250

303251
"""
@@ -561,14 +509,32 @@ hasmissing(::Type{>:Missing}) = true
561509
hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA)
562510
hasmissing(::Type{Union{}}) = false # issue #368
563511

512+
function splitarg_to_expr((arg_name, arg_type, is_splat, default))
513+
return is_splat ? :($arg_name...) : arg_name
514+
end
515+
516+
function namedtuple_from_splitargs(splitargs)
517+
names = map(splitargs) do (arg_name, arg_type, is_splat, default)
518+
is_splat ? Symbol("#splat#$(arg_name)") : arg_name
519+
end
520+
names_expr = Expr(:tuple, map(QuoteNode, names)...)
521+
vals = Expr(:tuple, map(first, splitargs)...)
522+
return :(NamedTuple{$names_expr}($vals))
523+
end
524+
525+
is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
526+
564527
"""
565-
build_output(modelinfo, linenumbernode)
528+
build_output(modeldef, linenumbernode)
566529
567530
Builds the output expression.
568531
"""
569-
function build_output(modelinfo, linenumbernode)
532+
function build_output(modeldef, linenumbernode)
533+
args = modeldef[:args]
534+
kwargs = modeldef[:kwargs]
535+
570536
## Build the anonymous evaluator from the user-provided model definition.
571-
evaluatordef = deepcopy(modelinfo[:modeldef])
537+
evaluatordef = deepcopy(modeldef)
572538

573539
# Add the internal arguments to the user-specified arguments (positional + keywords).
574540
evaluatordef[:args] = vcat(
@@ -577,12 +543,9 @@ function build_output(modelinfo, linenumbernode)
577543
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
578544
:(__context__::$(DynamicPPL.AbstractContext)),
579545
],
580-
modelinfo[:allargs_exprs],
546+
args,
581547
)
582548

583-
# Delete the keyword arguments.
584-
evaluatordef[:kwargs] = []
585-
586549
# Replace the user-provided function body with the version created by DynamicPPL.
587550
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
588551
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
@@ -593,18 +556,13 @@ function build_output(modelinfo, linenumbernode)
593556
# See the docstrings of `replace_returns` for more info.
594557
evaluatordef[:body] = MacroTools.@q begin
595558
$(linenumbernode)
596-
$(replace_returns(make_returns_explicit!(modelinfo[:body])))
559+
$(replace_returns(make_returns_explicit!(modeldef[:body])))
597560
end
598561

599562
## Build the model function.
600563

601-
# Extract the named tuple expression of all arguments and the default values.
602-
allargs_namedtuple = modelinfo[:allargs_namedtuple]
603-
defaults_namedtuple = modelinfo[:defaults_namedtuple]
604-
605564
# Obtain or generate the name of the model to support functors:
606565
# https://github.com/TuringLang/DynamicPPL.jl/issues/367
607-
modeldef = modelinfo[:modeldef]
608566
if MacroTools.@capture(modeldef[:name], ::T_)
609567
name = gensym(:f)
610568
modeldef[:name] = Expr(:(::), name, T)
@@ -613,13 +571,18 @@ function build_output(modelinfo, linenumbernode)
613571
throw(ArgumentError("unsupported format of model function"))
614572
end
615573

574+
args_split = map(MacroTools.splitarg, args)
575+
kwargs_split = map(MacroTools.splitarg, kwargs)
576+
args_nt = namedtuple_from_splitargs(args_split)
577+
kwargs_inclusion = map(splitarg_to_expr, kwargs_split)
578+
616579
# Update the function body of the user-specified model.
617580
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
618581
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
619582
# to the call site
620583
modeldef[:body] = MacroTools.@q begin
621584
$(linenumbernode)
622-
return $(DynamicPPL.Model)($name, $allargs_namedtuple, $defaults_namedtuple)
585+
return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...))
623586
end
624587

625588
return MacroTools.@q begin

src/model.jl

+28-7
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,17 @@ model with different arguments.
6767
@generated function Model(
6868
f::F,
6969
args::NamedTuple{argnames,Targs},
70-
defaults::NamedTuple=NamedTuple(),
70+
defaults::NamedTuple,
7171
context::AbstractContext=DefaultContext(),
7272
) where {F,argnames,Targs}
7373
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
7474
return :(Model{$missings}(f, args, defaults, context))
7575
end
7676

77+
function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
78+
return Model(f, args, NamedTuple(kwargs), context)
79+
end
80+
7781
function contextualize(model::Model, context::AbstractContext)
7882
return Model(model.f, model.args, model.defaults, context)
7983
end
@@ -177,7 +181,7 @@ julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
177181
m[2] ~ Normal()
178182
return m
179183
end
180-
demo_mv (generic function with 3 methods)
184+
demo_mv (generic function with 4 methods)
181185
182186
julia> model = demo_mv();
183187
@@ -376,7 +380,7 @@ julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
376380
m[2] ~ Normal()
377381
return m
378382
end
379-
demo_mv (generic function with 3 methods)
383+
demo_mv (generic function with 4 methods)
380384
381385
julia> model = demo_mv();
382386
@@ -573,12 +577,27 @@ end
573577
574578
Evaluate the `model` with the arguments matching the given `context` and `varinfo` object.
575579
"""
576-
@generated function _evaluate!!(
577-
model::Model{_F,argnames}, varinfo, context
580+
function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
581+
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo, context)
582+
return model.f(args...; kwargs...)
583+
end
584+
585+
"""
586+
make_evaluate_args_and_kwargs(model, varinfo, context)
587+
588+
Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`e.
589+
"""
590+
@generated function make_evaluate_args_and_kwargs(
591+
model::Model{_F,argnames}, varinfo::AbstractVarInfo, context::AbstractContext
578592
) where {_F,argnames}
579593
unwrap_args = [
580-
:($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
594+
if is_splat_symbol(var)
595+
:($matchingvalue(context_new, varinfo, model.args.$var)...)
596+
else
597+
:($matchingvalue(context_new, varinfo, model.args.$var))
598+
end for var in argnames
581599
]
600+
582601
# We want to give `context` precedence over `model.context` while also
583602
# preserving the leaf context of `context`. We can do this by
584603
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
@@ -590,7 +609,7 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf
590609
context_new = setleafcontext(
591610
context, setleafcontext(model.context, leafcontext(context))
592611
)
593-
model.f(
612+
args = (
594613
model,
595614
# Maybe perform `invlink!!` once prior to evaluation to avoid
596615
# lazy `invlink`-ing of the parameters. This can be useful for
@@ -600,6 +619,8 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf
600619
context_new,
601620
$(unwrap_args...),
602621
)
622+
kwargs = model.defaults
623+
return args, kwargs
603624
end
604625
end
605626

src/prob_macro.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ function probtype(
9292
return getfield(right, arg)
9393
elseif arg in defaultnames
9494
return getfield(defaults, arg)
95+
elseif arg in argnames
96+
return getfield(model.args, arg)
9597
else
9698
return nothing
9799
end
@@ -170,7 +172,7 @@ end
170172
push!(argvals, :(model.defaults.$argname))
171173
else
172174
push!(warnings, :(@warn($(warn_msg(argname)))))
173-
push!(argvals, :(nothing))
175+
push!(argvals, :(model.args.$argname))
174176
end
175177
end
176178

@@ -184,7 +186,7 @@ end
184186
end
185187
end
186188

187-
warn_msg(arg) = "Argument $arg is not defined. A value of `nothing` is used."
189+
warn_msg(arg) = "Argument $arg is not defined. Using the value from the model."
188190

189191
function Distributions.loglikelihood(
190192
left::NamedTuple, right::NamedTuple, _model::Model, _vi::Union{Nothing,VarInfo}
@@ -227,6 +229,8 @@ end
227229
push!(missings, argname)
228230
elseif argname in defaultnames
229231
push!(argvals, :(model.defaults.$argname))
232+
elseif argname in argnames
233+
push!(argvals, :(model.args.$argname))
230234
else
231235
throw(
232236
"This point should not be reached. Please open an issue in the DynamicPPL.jl repository.",

src/varname.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ Statically check whether the variable of name `varname` is an argument of the `m
2222
2323
Possibly existing indices of `varname` are neglected.
2424
"""
25-
@generated function inargnames(::VarName{s}, ::Model{_F,argnames}) where {s,argnames,_F}
26-
return s in argnames
25+
@generated function inargnames(
26+
::VarName{s}, ::Model{_F,argnames,defaultnames}
27+
) where {s,argnames,defaultnames,_F}
28+
return s in argnames || s in defaultnames
2729
end
2830

2931
"""

test/compiler.jl

+31-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
x ~ Normal()
5353
return x
5454
end
55-
@test length(methods(testmodel01)) == 3
55+
@test length(methods(testmodel01)) == 4
5656
f0_mm = testmodel01()
5757
@test mean(f0_mm() for _ in 1:1000) 0.0 atol = 0.1
5858

@@ -65,7 +65,7 @@ end
6565
x[2] ~ Normal()
6666
return x
6767
end
68-
@test length(methods(testmodel02)) == 3
68+
@test length(methods(testmodel02)) == 4
6969
f0_mm = testmodel02()
7070
@test all(x -> isapprox(x, 0; atol=0.1), mean(f0_mm() for _ in 1:1000))
7171

@@ -74,7 +74,7 @@ end
7474
return x
7575
end
7676
f01_mm = testmodel03()
77-
@test length(methods(testmodel03)) == 3
77+
@test length(methods(testmodel03)) == 4
7878
@test mean(f01_mm() for _ in 1:1000) 0.5 atol = 0.1
7979

8080
# test if we get the correct return values
@@ -620,4 +620,32 @@ end
620620
@test f_393()() == 1
621621
@test f_393(Val(true))() == 0
622622
end
623+
624+
@testset "splatting of args and kwargs" begin
625+
@model function f_splat_test_1(x; y::T=1, kwargs...) where {T}
626+
x ~ Normal(y, 1)
627+
return x, y, T, NamedTuple(kwargs)
628+
end
629+
630+
# Non-empty `kwargs...`.
631+
res = f_splat_test_1(1; z=2, w=3)()
632+
@test res == (1, 1, Int, (z=2, w=3))
633+
634+
# Empty `kwargs...`.
635+
res = f_splat_test_1(1)()
636+
@test res == (1, 1, Int, NamedTuple())
637+
638+
@model function f_splat_test_2(x, args...; y::T=1, kwargs...) where {T}
639+
x ~ Normal(y, 1)
640+
return x, args, y, T, NamedTuple(kwargs)
641+
end
642+
643+
# Non-empty `args...` and non-empty `kwargs...`.
644+
res = f_splat_test_2(1, 2, 3; z=2, w=3)()
645+
@test res == (1, (2, 3), 1, Int, (z=2, w=3))
646+
647+
# Empty `args...` and empty `kwargs...`.
648+
res = f_splat_test_2(1)()
649+
@test res == (1, (), 1, Int, NamedTuple())
650+
end
623651
end

0 commit comments

Comments
 (0)