Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5.10"
AbstractPPL = "0.13.1"
AbstractPPL = "0.14"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.15.11"
Expand Down Expand Up @@ -76,3 +76,6 @@ Random = "1.6"
Statistics = "1"
Test = "1.6"
julia = "1.10.8"

[sources]
AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"}
3 changes: 3 additions & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "d94a1522-c11e-44a7-981a-42bf5dc1a001"
version = "0.1.0"

[deps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -18,9 +19,11 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[sources]
AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"}
DynamicPPL = {path = "../"}

[compat]
AbstractPPL = "0.14"
ADTypes = "1.14.0"
Chairmarks = "1.3.1"
Distributions = "0.25.117"
Expand Down
5 changes: 4 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
AbstractMCMC = "5"
AbstractPPL = "0.13"
AbstractPPL = "0.14"
Accessors = "0.1"
Distributions = "0.25"
Documenter = "1"
Expand All @@ -29,3 +29,6 @@ LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
MarginalLogDensities = "0.4"
StableRNGs = "1"

[sources]
AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"}
4 changes: 2 additions & 2 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DynamicPPLMarginalLogDensitiesExt

using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, AbstractPPL
using MarginalLogDensities: MarginalLogDensities

# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
Expand Down Expand Up @@ -105,7 +105,7 @@ function DynamicPPL.marginalize(
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
# Determine the indices for the variables to marginalise out.
varindices = mapreduce(vcat, marginalized_varnames) do vn
if DynamicPPL.getoptic(vn) === identity
if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
else
ldf._varname_ranges[vn].range
Expand Down
8 changes: 4 additions & 4 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])

julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries:
OrderedDict{VarName{sym, AbstractPPL.Iden} where sym, Any} with 2 entries:
x => 1.0
m => [2.0]

Expand Down Expand Up @@ -564,7 +564,7 @@ julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)

julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
OrderedDict{VarName{sym, AbstractPPL.Iden} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0

Expand All @@ -590,7 +590,7 @@ julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)

julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
OrderedDict{VarName{sym, AbstractPPL.Iden} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0

Expand Down Expand Up @@ -680,7 +680,7 @@ julia> # Extract one with only `m`.


julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, typeof(identity)}}:
1-element Vector{VarName{:m, AbstractPPL.Iden}}:
m

julia> varinfo_subset1[@varname(m)]
Expand Down
114 changes: 51 additions & 63 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,20 @@
const INTERNALNAMES = (:__model__, :__varinfo__)

"""
need_concretize(expr)

Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
requires a dynamic optic.

# Examples

```jldoctest; setup=:(using Accessors)
julia> DynamicPPL.need_concretize(:(x[1, :]))
true

julia> DynamicPPL.need_concretize(:(x[1, end]))
true

julia> DynamicPPL.need_concretize(:(x[1, 1]))
false
"""
function need_concretize(expr)
return Accessors.need_dynamic_optic(expr) || begin
flag = false
MacroTools.postwalk(expr) do ex
# Concretise colon by default
ex == :(:) && (flag = true) && return ex
end
flag
end
drop_escape(x) = x
function drop_escape(expr::Expr)
Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1])
return Expr(expr.head, map(x -> drop_escape(x), expr.args)...)
end

"""
make_varname_expression(expr)

Return a `VarName` based on `expr`, concretizing it if necessary.
Return a `VarName` based on `expr`.
"""
function make_varname_expression(expr)
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
# that in DynamicPPL we the entire function body. Instead we should be
# more selective with our escape. Until that's the case, we remove them all.
return AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact that in
# DynamicPPL we escape the entire function body. Instead we should be more selective
# with our escape. Until that's the case, we remove them all.
return drop_escape(AbstractPPL.varname(expr, false))
end

"""
Expand All @@ -55,10 +31,9 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.

If `vn` is specified, it will be assumed to refer to a expression which
evaluates to a `VarName`, and this will be used in the subsequent checks.
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
used in its place.
If `vn` is specified, it will be assumed to refer to a expression which evaluates to a
`VarName`, and this will be used in the subsequent checks. If `vn` is not specified,
`(@varname \$expr)` will be used in its place.
"""
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
return quote
Expand Down Expand Up @@ -221,9 +196,6 @@ variables.

# Example
```jldoctest; setup=:(using Distributions, LinearAlgebra)
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
x[:, 2]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
x[1, 2]

Expand All @@ -241,31 +213,20 @@ end
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)
return unwrap_right_left_vns(right.dist, left, right.name)
end
function unwrap_right_left_vns(
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
)
# This an expression such as `x .~ MvNormal()` which we interpret as
# x[:, i] ~ MvNormal()
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) ∘ vn, left)
end
return unwrap_right_left_vns(right, left, vns)
end
function unwrap_right_left_vns(
right::Union{Distribution,AbstractArray{<:Distribution}},
left::AbstractArray,
vn::VarName,
)
vns = map(CartesianIndices(left)) do i
return Accessors.IndexLens(Tuple(i)) ∘ vn
sym, optic = getsym(vn), getoptic(vn)
return VarName{sym}(AbstractPPL.Index(Tuple(i), (;), AbstractPPL.Iden()) ∘ optic)
end
return unwrap_right_left_vns(right, left, vns)
end

resolve_varnames(vn::VarName, _) = vn
resolve_varnames(vn::VarName, dist::NamedDist) = dist.name
resolve_varnames(::VarName, dist::NamedDist) = dist.name

#################
# Main Compiler #
Expand Down Expand Up @@ -463,9 +424,18 @@ function generate_tilde_literal(left, right)
end
end

assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs))
function assign_or_set!!(lhs::Expr, rhs)
return AbstractPPL.drop_escape(:($BangBang.@set!! $lhs = $rhs))
assign_or_set!!(lhs::Symbol, rhs, vn) = drop_escape(:($lhs = $rhs))
function assign_or_set!!(lhs::Expr, rhs, vn)
left_top_sym = get_top_level_symbol(lhs)
return drop_escape(
:(
$left_top_sym = $(Accessors.set)(
$left_top_sym,
$(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)),
$rhs,
)
),
)
end

"""
Expand All @@ -487,12 +457,13 @@ function generate_tilde(left, right)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
# $left may not be a simple varname, it might be x.a or x[1], in which case we
# need to use BangBang.@set!! to safely set it.
# need to use Accessors.set to safely set it.
$(assign_or_set!!(
left,
:($(DynamicPPL.getfixed_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)),
vn,
))
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
Expand Down Expand Up @@ -520,22 +491,39 @@ function generate_tilde(left, right)
$vn,
__varinfo__,
)
$(assign_or_set!!(left, value))
$(assign_or_set!!(left, value, vn))
$value
end
end
end

get_top_level_symbol(expr::Symbol) = expr
function get_top_level_symbol(expr::Expr)
# TODO(penelopeysm): what about Meta.isexpr(expr, :$)?
if Meta.isexpr(expr, :ref)
return get_top_level_symbol(expr.args[1])
elseif Meta.isexpr(expr, :.)
return get_top_level_symbol(expr.args[1])
else
error("unreachable")
end
end
function generate_tilde_assume(left, right, vn)
# HACK: Because the Setfield.jl macro does not support assignment
# with multiple arguments on the LHS, we need to capture the return-values
# and then update the LHS variables one by one.
@gensym value
expr = :($left = $value)
if left isa Expr
expr = AbstractPPL.drop_escape(
Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true)
expr = if left isa Expr # as opposed to Symbol
left_top_sym = get_top_level_symbol(left)
:(
$left_top_sym = $(Accessors.set)(
$left_top_sym,
$(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)),
$value,
)
)
else
:($left = $value)
end

return quote
Expand Down
4 changes: 2 additions & 2 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia
of automatically promoting the types on its own. Secondly, the promotion only matters if you
are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar
tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to
tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which
tilde-statements like `xs[i] ~ ...` because those use `Accessors.set` under the hood, which
also does the promotion for you. For the gory details, see the following issues:

- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types
Expand Down Expand Up @@ -260,7 +260,7 @@ struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
end

function _get_range_and_linked(
vr::VectorWithRanges, ::VarName{sym,typeof(identity)}
vr::VectorWithRanges, ::VarName{sym,AbstractPPL.Iden}
) where {sym}
return vr.iden_varname_ranges[sym]
end
Expand Down
1 change: 0 additions & 1 deletion src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module DebugUtils
using ..DynamicPPL

using Random: Random
using Accessors: Accessors
using InteractiveUtils: InteractiveUtils

using DocStringExtensions
Expand Down
4 changes: 2 additions & 2 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
for (vn, idx) in md.idcs
is_linked = md.is_transformed[idx]
range = md.ranges[idx] .+ (start_offset - 1)
if AbstractPPL.getoptic(vn) === identity
if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden
all_iden_ranges = merge(
all_iden_ranges,
NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
Expand All @@ -376,7 +376,7 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
for (vn, idx) in vnv.varname_to_index
is_linked = vnv.is_unconstrained[idx]
range = vnv.ranges[idx] .+ (start_offset - 1)
if AbstractPPL.getoptic(vn) === identity
if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden
all_iden_ranges = merge(
all_iden_ranges,
NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
Expand Down
Loading
Loading