diff --git a/Project.toml b/Project.toml index e02ee621e..ffe8b7660 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5.10" -AbstractPPL = "0.13.1" +AbstractPPL = "0.14" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.15.11" @@ -76,3 +76,6 @@ Random = "1.6" Statistics = "1" Test = "1.6" julia = "1.10.8" + +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"} diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 523889a7a..32ad336cc 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -3,6 +3,7 @@ uuid = "d94a1522-c11e-44a7-981a-42bf5dc1a001" version = "0.1.0" [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -18,9 +19,11 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"} DynamicPPL = {path = "../"} [compat] +AbstractPPL = "0.14" ADTypes = "1.14.0" Chairmarks = "1.3.1" Distributions = "0.25.117" diff --git a/docs/Project.toml b/docs/Project.toml index 10a4a5c8a..c5a530ae3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -16,7 +16,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] AbstractMCMC = "5" -AbstractPPL = "0.13" +AbstractPPL = "0.14" Accessors = "0.1" Distributions = "0.25" Documenter = "1" @@ -29,3 +29,6 @@ LogDensityProblems = "2" MCMCChains = "5, 6, 7" MarginalLogDensities = "0.4" StableRNGs = "1" + +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"} diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index ffb5baf25..149e1bcd5 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMarginalLogDensitiesExt -using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, AbstractPPL using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by @@ -105,7 +105,7 @@ function DynamicPPL.marginalize( ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. varindices = mapreduce(vcat, marginalized_varnames) do vn - if DynamicPPL.getoptic(vn) === identity + if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range else ldf._varname_ranges[vn].range diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 898b6caf9..32120c7e5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -514,7 +514,7 @@ julia> values_as(SimpleVarInfo(data), NamedTuple) (x = 1.0, m = [2.0]) julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: +OrderedDict{VarName{sym, AbstractPPL.Iden} where sym, Any} with 2 entries: x => 1.0 m => [2.0] @@ -564,7 +564,7 @@ julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: +OrderedDict{VarName{sym, AbstractPPL.Iden} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 @@ -590,7 +590,7 @@ julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: +OrderedDict{VarName{sym, AbstractPPL.Iden} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 @@ -680,7 +680,7 @@ julia> # Extract one with only `m`. julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, typeof(identity)}}: +1-element Vector{VarName{:m, AbstractPPL.Iden}}: m julia> varinfo_subset1[@varname(m)] diff --git a/src/compiler.jl b/src/compiler.jl index f1e92e369..52d44c68a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,44 +1,20 @@ const INTERNALNAMES = (:__model__, :__varinfo__) -""" - need_concretize(expr) - -Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or -requires a dynamic optic. - -# Examples - -```jldoctest; setup=:(using Accessors) -julia> DynamicPPL.need_concretize(:(x[1, :])) -true - -julia> DynamicPPL.need_concretize(:(x[1, end])) -true - -julia> DynamicPPL.need_concretize(:(x[1, 1])) -false -""" -function need_concretize(expr) - return Accessors.need_dynamic_optic(expr) || begin - flag = false - MacroTools.postwalk(expr) do ex - # Concretise colon by default - ex == :(:) && (flag = true) && return ex - end - flag - end +drop_escape(x) = x +function drop_escape(expr::Expr) + Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1]) + return Expr(expr.head, map(x -> drop_escape(x), expr.args)...) end - """ make_varname_expression(expr) -Return a `VarName` based on `expr`, concretizing it if necessary. +Return a `VarName` based on `expr`. """ function make_varname_expression(expr) - # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact - # that in DynamicPPL we the entire function body. Instead we should be - # more selective with our escape. Until that's the case, we remove them all. - return AbstractPPL.drop_escape(varname(expr, need_concretize(expr))) + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact that in + # DynamicPPL we escape the entire function body. Instead we should be more selective + # with our escape. Until that's the case, we remove them all. + return drop_escape(AbstractPPL.varname(expr, false)) end """ @@ -55,10 +31,9 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. -If `vn` is specified, it will be assumed to refer to a expression which -evaluates to a `VarName`, and this will be used in the subsequent checks. -If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be -used in its place. +If `vn` is specified, it will be assumed to refer to a expression which evaluates to a +`VarName`, and this will be used in the subsequent checks. If `vn` is not specified, +`(@varname \$expr)` will be used in its place. """ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote @@ -221,9 +196,6 @@ variables. # Example ```jldoctest; setup=:(using Distributions, LinearAlgebra) -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] -x[:, 2] - julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] x[1, 2] @@ -241,31 +213,20 @@ end function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName) return unwrap_right_left_vns(right.dist, left, right.name) end -function unwrap_right_left_vns( - right::MultivariateDistribution, left::AbstractMatrix, vn::VarName -) - # This an expression such as `x .~ MvNormal()` which we interpret as - # x[:, i] ~ MvNormal() - # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, - # and we therefore add the `Colon()` below. - vns = map(axes(left, 2)) do i - return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) ∘ vn, left) - end - return unwrap_right_left_vns(right, left, vns) -end function unwrap_right_left_vns( right::Union{Distribution,AbstractArray{<:Distribution}}, left::AbstractArray, vn::VarName, ) vns = map(CartesianIndices(left)) do i - return Accessors.IndexLens(Tuple(i)) ∘ vn + sym, optic = getsym(vn), getoptic(vn) + return VarName{sym}(AbstractPPL.Index(Tuple(i), (;), AbstractPPL.Iden()) ∘ optic) end return unwrap_right_left_vns(right, left, vns) end resolve_varnames(vn::VarName, _) = vn -resolve_varnames(vn::VarName, dist::NamedDist) = dist.name +resolve_varnames(::VarName, dist::NamedDist) = dist.name ################# # Main Compiler # @@ -463,9 +424,18 @@ function generate_tilde_literal(left, right) end end -assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs)) -function assign_or_set!!(lhs::Expr, rhs) - return AbstractPPL.drop_escape(:($BangBang.@set!! $lhs = $rhs)) +assign_or_set!!(lhs::Symbol, rhs, vn) = drop_escape(:($lhs = $rhs)) +function assign_or_set!!(lhs::Expr, rhs, vn) + left_top_sym = get_top_level_symbol(lhs) + return drop_escape( + :( + $left_top_sym = $(Accessors.set)( + $left_top_sym, + $(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)), + $rhs, + ) + ), + ) end """ @@ -487,12 +457,13 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) # $left may not be a simple varname, it might be x.a or x[1], in which case we - # need to use BangBang.@set!! to safely set it. + # need to use Accessors.set to safely set it. $(assign_or_set!!( left, :($(DynamicPPL.getfixed_nested)( __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) )), + vn, )) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) @@ -520,22 +491,39 @@ function generate_tilde(left, right) $vn, __varinfo__, ) - $(assign_or_set!!(left, value)) + $(assign_or_set!!(left, value, vn)) $value end end end +get_top_level_symbol(expr::Symbol) = expr +function get_top_level_symbol(expr::Expr) + # TODO(penelopeysm): what about Meta.isexpr(expr, :$)? + if Meta.isexpr(expr, :ref) + return get_top_level_symbol(expr.args[1]) + elseif Meta.isexpr(expr, :.) + return get_top_level_symbol(expr.args[1]) + else + error("unreachable") + end +end function generate_tilde_assume(left, right, vn) # HACK: Because the Setfield.jl macro does not support assignment # with multiple arguments on the LHS, we need to capture the return-values # and then update the LHS variables one by one. @gensym value - expr = :($left = $value) - if left isa Expr - expr = AbstractPPL.drop_escape( - Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true) + expr = if left isa Expr # as opposed to Symbol + left_top_sym = get_top_level_symbol(left) + :( + $left_top_sym = $(Accessors.set)( + $left_top_sym, + $(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)), + $value, + ) ) + else + :($left = $value) end return quote diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 80a494c23..7c4db1e4f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -60,7 +60,7 @@ in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia of automatically promoting the types on its own. Secondly, the promotion only matters if you are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to -tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +tilde-statements like `xs[i] ~ ...` because those use `Accessors.set` under the hood, which also does the promotion for you. For the gory details, see the following issues: - https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types @@ -260,7 +260,7 @@ struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} end function _get_range_and_linked( - vr::VectorWithRanges, ::VarName{sym,typeof(identity)} + vr::VectorWithRanges, ::VarName{sym,AbstractPPL.Iden} ) where {sym} return vr.iden_varname_ranges[sym] end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 8810b9819..79e625e36 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -3,7 +3,6 @@ module DebugUtils using ..DynamicPPL using Random: Random -using Accessors: Accessors using InteractiveUtils: InteractiveUtils using DocStringExtensions diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 3008a329b..ec97cef08 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -357,7 +357,7 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) for (vn, idx) in md.idcs is_linked = md.is_transformed[idx] range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity + if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden all_iden_ranges = merge( all_iden_ranges, NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), @@ -376,7 +376,7 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) for (vn, idx) in vnv.varname_to_index is_linked = vnv.is_unconstrained[idx] range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity + if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden all_iden_ranges = merge( all_iden_ranges, NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), diff --git a/src/model.jl b/src/model.jl index 8bfeaf6a1..5637c799c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -501,19 +501,19 @@ true julia> # Since we conditioned on `a.m`, it is not treated as a random variable. # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName{:a, AbstractPPL.Property{:x, AbstractPPL.Iden}}}: a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); julia> conditioned(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: +Dict{VarName{:a, AbstractPPL.Property{:m, AbstractPPL.Iden}}, Float64} with 1 entry: a.m => 1.0 julia> # Now `a.x` will be sampled. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName{:a, AbstractPPL.Property{:x, AbstractPPL.Iden}}}: a.x ``` """ @@ -833,25 +833,25 @@ julia> # Returns all the variables we have fixed on + their values. (x = 100.0, m = 1.0) julia> # The rest of this is the same as the `condition` example above. - cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); + fm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); -julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) +julia> Set(keys(fixed(fm))) == Set([@varname(a.m), @varname(x)]) true -julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +julia> keys(VarInfo(fm)) +1-element Vector{VarName{:a, AbstractPPL.Property{:x, AbstractPPL.Iden}}}: a.x -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); +julia> # We can also fix `a.m` _outside_ of the PrefixContext: + fm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> fixed(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: +julia> fixed(fm) +Dict{VarName{:a, AbstractPPL.Property{:m, AbstractPPL.Iden}}, Float64} with 1 entry: a.m => 1.0 julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + keys(VarInfo(fm)) +1-element Vector{VarName{:a, AbstractPPL.Property{:x, AbstractPPL.Iden}}}: a.x ``` """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9d3fb1925..3889f3025 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -154,11 +154,11 @@ julia> svi_nt[@varname(m.a[1])] 1.0 julia> svi_nt[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: m.a[2] was not found in the NamedTuple provided [...] julia> svi_nt[@varname(m.b)] -ERROR: FieldError: type NamedTuple has no field `b`, available fields: `a` +ERROR: m.b was not found in the NamedTuple provided [...] ``` @@ -327,7 +327,7 @@ Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getinde Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) -getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) +getindex_internal(vi::SimpleVarInfo, vn::VarName) = AbstractPPL.getvalue(vi.values, vn) # `AbstractDict` function getindex_internal( vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName @@ -354,28 +354,36 @@ end function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) # For dictlike objects, we treat the entire `vn` as a _key_ to set. dict = values_as(vi) - # Attempt to split into `parent` and `child` optic. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(dict, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - dict_new = if !issuccess - # Split doesn't exist ⟹ we're working with a new key. - BangBang.setindex!!(dict, val, vn) - else - # Split exists ⟹ trying to set an existing key. - vn_key = VarName{getsym(vn)}(keyoptic) - BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) + dict_new = dict + test_vn = vn + test_optic = AbstractPPL.Iden() + found = false + while true + if haskey(dict, test_vn) && AbstractPPL.canview(dict[test_vn], test_optic) + found = true + new_value = Accessors.set(dict[test_vn], test_optic, val) + dict_new = BangBang.setindex!!(dict, new_value, test_vn) + break + end + # Split the last optic off from the VarName and try again. + test_vn_optic = getoptic(test_vn) + if test_vn_optic isa AbstractPPL.Iden + # Ran out of options. + break + end + test_vn = VarName{getsym(test_vn)}(AbstractPPL.oinit(test_vn_optic)) + test_optic = test_optic ∘ AbstractPPL.olast(test_vn_optic) + end + if !found + dict_new = BangBang.setindex!!(dict, val, vn) end return Accessors.@set vi.values = dict_new end # `NamedTuple` function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution + vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,AbstractPPL.Iden}, value, ::Distribution ) where {sym} return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end @@ -432,7 +440,7 @@ end function _subset(x::NamedTuple, vns) # NOTE: Here we can only handle `vns` that contain `identity` as optic. - if any(Base.Fix1(!==, identity) ∘ getoptic, vns) + if any(vn -> !(getoptic(vn) isa AbstractPPL.Iden), vns) throw( ArgumentError( "Cannot subset `NamedTuple` with non-`identity` `VarName`. " * diff --git a/src/test_utils.jl b/src/test_utils.jl index f584055b3..9f5f018cf 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -2,6 +2,7 @@ module TestUtils using AbstractMCMC using DynamicPPL +using AbstractPPL: AbstractPPL using LinearAlgebra using Distributions using Test diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index e7fb16fbe..fd7f7f553 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -104,7 +104,7 @@ Return a `NamedTuple` compatible with `varnames(model)` where the values represe the posterior mean under `model`. "Compatible" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`. +corresponding value using e.g. `AbstractPPL.getvalue(posterior_mean(model), varname)`. """ function posterior_mean end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6483b29e8..7b2971594 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -10,7 +10,7 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in """ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) for vn in vns - val = get(vals, vn) + val = AbstractPPL.getvalue(vals, vn) # TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 # Remove once the fix is all Julia versions we support. if val isa Cholesky diff --git a/src/utils.jl b/src/utils.jl index 75fb805dc..2c6b38349 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -560,200 +560,12 @@ collect_maybe(x::AbstractArray) = x ####################### # BangBang.jl related # ####################### -function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) +function set!!(obj, optic::AbstractPPL.AbstractOptic, value) opticmut = BangBang.prefermutation(optic) return Accessors.set(obj, opticmut, value) end -function set!!(obj, vn::VarName{sym}, value) where {sym} - optic = BangBang.prefermutation( - AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() - ) - return Accessors.set(obj, optic, value) -end - -############################# -# AbstractPPL.jl extensions # -############################# -# This is preferable to `haskey` because the order of arguments is different, and -# we're more likely to specialize on the key in these settings rather than the container. -# TODO: I'm not sure about this name. -""" - canview(optic, container) - -Return `true` if `optic` can be used to view `container`, and `false` otherwise. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) -julia> canview(@o(_.a), (a = 1.0, )) -true - -julia> canview(@o(_.a), (b = 1.0, )) # property `a` does not exist -false - -julia> canview(@o(_.a[1]), (a = [1.0, 2.0], )) -true - -julia> canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds -false -``` -""" -canview(optic, container) = false -canview(::typeof(identity), _) = true -function canview(optic::Accessors.PropertyLens{field}, x) where {field} - return hasproperty(x, field) -end - -# `IndexLens`: only relevant if `x` supports indexing. -canview(optic::Accessors.IndexLens, x) = false -function canview(optic::Accessors.IndexLens, x::AbstractArray) - return checkbounds(Bool, x, optic.indices...) -end - -# `ComposedOptic`: check that we can view `.inner` and `.outer`, but using -# value extracted using `.inner`. -function canview(optic::Accessors.ComposedOptic, x) - return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) -end - -""" - parent(vn::VarName) - -Return the parent `VarName`. - -# Examples -```julia-repl; setup=:(using DynamicPPL: parent) -julia> parent(@varname(x.a[1])) -x.a - -julia> (parent ∘ parent)(@varname(x.a[1])) -x - -julia> (parent ∘ parent ∘ parent)(@varname(x.a[1])) -x -``` -""" -function parent(vn::VarName) - p = parent(getoptic(vn)) - return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p) -end - -""" - parent(optic) - -Return the parent optic. If `optic` doesn't have a parent, -`nothing` is returned. - -See also: [`parent_and_child`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent) -julia> parent(@o(_.a[1])) -(@o _.a) - -julia> # Parent of optic without parents results in `nothing`. - (parent ∘ parent)(@o(_.a[1])) === nothing -true -``` -""" -parent(optic::AbstractPPL.ALLOWED_OPTICS) = first(parent_and_child(optic)) - -""" - parent_and_child(optic) - -Return a 2-tuple of optics `(parent, child)` where `parent` is the -parent optic of `optic` and `child` is the child optic of `optic`. - -If `optic` does not have a parent, we return `(nothing, optic)`. - -See also: [`parent`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent_and_child) -julia> parent_and_child(@o(_.a[1])) -((@o _.a), (@o _[1])) - -julia> parent_and_child(@o(_.a)) -(nothing, (@o _.a)) -``` -""" -parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) -function parent_and_child(optic::Accessors.ComposedOptic) - p, child = parent_and_child(optic.outer) - parent = p === nothing ? optic.inner : p ∘ optic.inner - return parent, child -end - -""" - splitoptic(condition, optic) - -Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a optic such that `condition(parent)` is `true` and `child ∘ parent == optic`. - -If `issuccess` is `false`, then no such split could be found. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: splitoptic) -julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent - # Succeeds! - parent == @o(_.a) - end -((@o _.a), (@o _[1]), true) - -julia> c ∘ p -(@o _.a[1]) - -julia> splitoptic(@o(_.a[1])) do parent - # Fails! - parent == @o(_.b) - end -(nothing, (@o _.a[1]), false) -``` -""" -function splitoptic(condition, optic) - current_parent, current_child = parent_and_child(optic) - # We stop if either a) `condition` is satisfied, or b) we reached the root. - while !condition(current_parent) && current_parent !== nothing - current_parent, c = parent_and_child(current_parent) - current_child = current_child ∘ c - end - - return current_parent, current_child, condition(current_parent) -end - -""" - remove_parent_optic(vn_parent::VarName, vn_child::VarName) - -Remove the parent optic `vn_parent` from `vn_child`. - -# Examples -```jldoctest; setup = :(using Accessors; using DynamicPPL: remove_parent_optic) -julia> remove_parent_optic(@varname(x), @varname(x.a)) -(@o _.a) - -julia> remove_parent_optic(@varname(x), @varname(x.a[1])) -(@o _.a[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1])) -(@o _[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1].b)) -(@o _[1].b) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a)) -ERROR: Could not find x.a in x.a - -julia> remove_parent_optic(@varname(x.a[2]), @varname(x.a[1])) -ERROR: Could not find x.a[2] in x.a[1] -``` -""" -function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} - _, child, issuccess = splitoptic(getoptic(vn_child)) do optic - o = optic === nothing ? identity : optic - o == getoptic(vn_parent) - end - - issuccess || error("Could not find $vn_parent in $vn_child") - return child +function set!!(obj, vn::VarName, value) + return set!!(obj, AbstractPPL.varname_to_optic(vn), value) end # HACK(torfjelde): This makes it so it works on iterators, etc. by default. @@ -804,7 +616,7 @@ Return instance similar to `vi` but with `vns` set to values from `vals`. """ function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) for vn in vns - vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn) + vi = DynamicPPL.setindex!!(vi, AbstractPPL.getvalue(vals, vn), vn) end return vi end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..8a1451a13 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1635,27 +1635,20 @@ function hasvalue(vnv::VarNamedVector, vn::VarName) # Handle the easy case where the right symbol isn't even present. !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false + # If vn is of the form @varname(somesymbol[someindex]), we check whether we store + # @varname(somesymbol) and can index into it with someindex. If we rather have a + # composed optic with the last part being an index lens, we do a similar check but + # stripping out the last index lens part. If these pass, the answer is definitely + # "yes". If not, we still don't know for sure. + # TODO(mhauru) What about casese where vnv stores both @varname(x) and + # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently + # aren't. optic = getoptic(vn) - if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic - # If vn is of the form @varname(somesymbol[someindex]), we check whether we store - # @varname(somesymbol) and can index into it with someindex. If we rather have a - # composed optic with the last part being an index lens, we do a similar check but - # stripping out the last index lens part. If these pass, the answer is definitely - # "yes". If not, we still don't know for sure. - # TODO(mhauru) What about casese where vnv stores both @varname(x) and - # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently - # aren't. - head, tail = if optic isa Accessors.ComposedOptic - decomp_optic = Accessors.decompose(optic) - first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) - else - optic, identity - end - parent_varname = VarName{getsym(vn)}(tail) - if haskey(vnv, parent_varname) - valvec = getindex(vnv, parent_varname) - return canview(head, valvec) - end + last, init = AbstractPPL.olast(optic), AbstractPPL.oinit(optic) + parent_varname = VarName{getsym(vn)}(init) + if haskey(vnv, parent_varname) + valvec = getindex(vnv, parent_varname) + return AbstractPPL.canview(last, valvec) end throw(ErrorException("hasvalue has not been fully implemented for this VarName: $(vn)")) end @@ -1672,17 +1665,13 @@ function getvalue(vnv::VarNamedVector, vn::VarName) end optic = getoptic(vn) - # See hasvalue for some comments on the logic of this if block. - if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic - head, tail = if optic isa Accessors.ComposedOptic - decomp_optic = Accessors.decompose(optic) - first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) - else - optic, identity - end - parent_varname = VarName{getsym(vn)}(tail) + # See hasvalue for some comments on the logic of this. + optic = getoptic(vn) + last, init = AbstractPPL.olast(optic), AbstractPPL.oinit(optic) + parent_varname = VarName{getsym(vn)}(init) + if haskey(vnv, parent_varname) valvec = getindex(vnv, parent_varname) - return head(valvec) + return last(valvec) end throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)")) end diff --git a/test/Project.toml b/test/Project.toml index 927954ba4..90863fb83 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -34,7 +34,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5.10" -AbstractPPL = "0.13" +AbstractPPL = "0.14" Accessors = "0.1" Aqua = "0.8" BangBang = "0.4" @@ -58,3 +58,6 @@ SpecialFunctions = "2.6.1" StableRNGs = "1" Zygote = "0.6, 0.7" julia = "1.10" + +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/newvarname"} diff --git a/test/contexts.jl b/test/contexts.jl index cdd32f379..3827e3ec2 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, Accessors +using Test, DynamicPPL using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 011cb22ce..d12828271 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -36,7 +36,7 @@ using Mooncake: Mooncake for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using the ranges # directly - range_with_linked = if AbstractPPL.getoptic(vn) === identity + range_with_linked = if AbstractPPL.getoptic(vn) isa AbstractPPL.Iden nt_ranges[AbstractPPL.getsym(vn)] else dict_ranges[vn] diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 780d45b46..fde807dda 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -5,7 +5,7 @@ # Instantiate a `VarInfo` with the example values. vi = VarInfo(model) for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + vi = DynamicPPL.setindex!!(vi, AbstractPPL.getvalue(example_values, vn), vn) end loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 42e377440..acbb43d93 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -105,7 +105,9 @@ continue end for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) + vi = DynamicPPL.setindex!!( + vi, AbstractPPL.getvalue(values_constrained, vn), vn + ) end vi = last(DynamicPPL.evaluate!!(model, vi)) @@ -136,7 +138,7 @@ # Should result in same values. @test all( DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ - DynamicPPL.tovec(get(values_constrained, vn)) for + DynamicPPL.tovec(AbstractPPL.getvalue(values_constrained, vn)) for vn in DynamicPPL.TestUtils.varnames(model) ) end @@ -173,7 +175,7 @@ # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_new[vn] != get(retval, vn) + @test svi_new[vn] != AbstractPPL.getvalue(retval, vn) end # Logjoint should be non-zero wp. 1. @@ -209,7 +211,9 @@ # Update the realizations in `svi_new`. svi_eval = svi_new for vn in DynamicPPL.TestUtils.varnames(model) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) + svi_eval = DynamicPPL.setindex!!( + svi_eval, AbstractPPL.getvalue(values_eval, vn), vn + ) end # Reset the logp accumulators. @@ -225,7 +229,7 @@ # TODO(mhauru) Workaround for # https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 # Remove once the fix is all Julia versions we support. - val = get(values_eval, vn) + val = AbstractPPL.getvalue(values_eval, vn) if val isa Cholesky @test svi_eval[vn].L == val.L else @@ -267,7 +271,8 @@ # Realizations from model should all be equal to the unconstrained realization. for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + @test AbstractPPL.getvalue(retval_unconstrained, vn) ≈ svi[vn] rtol = + 1e-6 end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. diff --git a/test/varinfo.jl b/test/varinfo.jl index a7948cc32..eeb74f3c0 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -480,7 +480,7 @@ end @test getindex(vals, Symbol(vn)) == getindex(vi, vn) else # Assumed to be of form `(m = [1.0, ...], ...)`. - @test get(vals, vn) == getindex(vi, vn) + @test AbstractPPL.getvalue(vals, vn) == getindex(vi, vn) end end end