# AdvancedVI.jl

Implementation of variational Bayes inference algorithms.
| AD Backend | Integration Status |
|---|---|
| ForwardDiff | (CI status badge) |
| ReverseDiff | (CI status badge) |
| Zygote | (CI status badge) |
| Mooncake | (CI status badge) |
| Enzyme | (CI status badge) |
AdvancedVI provides implementations of variational inference (VI) algorithms, a family of algorithms aiming for scalable approximate Bayesian inference by leveraging optimization. AdvancedVI is part of the Turing probabilistic programming ecosystem. The purpose of this package is to provide a common, accessible interface for various VI algorithms and utilities so that other packages, e.g. Turing, only need to write a light wrapper for integration. For example, integrating Turing with `AdvancedVI.ADVI` only involves converting a `Turing.Model` into a `LogDensityProblem` and extracting a corresponding `Bijectors.bijector`.
We will describe a simple example to demonstrate the basic usage of AdvancedVI. AdvancedVI works with differentiable models specified through the `LogDensityProblem` interface. Let's look at a basic logistic regression example with a hierarchical prior. For a dataset $(X, y)$ with features $X \in \mathbb{R}^{n \times d}$ and binary labels $y \in \{0, 1\}^n$, the model is

$$
\begin{aligned}
\sigma &\sim \mathrm{LogNormal}(0, 3), \\
\beta &\sim \mathcal{N}\left(0_d, \sigma^2 \mathrm{I}_d\right), \\
y_i &\sim \mathrm{BernoulliLogit}\left(x_i^{\top} \beta\right), \quad i = 1, \ldots, n .
\end{aligned}
$$
The `LogDensityProblem` corresponding to this model can be constructed as follows:
```julia
using LogDensityProblems: LogDensityProblems
using Distributions
using FillArrays

struct LogReg{XType,YType}
    X::XType
    y::YType
end

function LogDensityProblems.logdensity(model::LogReg, θ)
    (; X, y) = model
    d = size(X, 2)
    β, σ = θ[1:d], θ[end]
    logprior_β = logpdf(MvNormal(Zeros(d), σ), β)
    logprior_σ = logpdf(LogNormal(0, 3), σ)
    logit = X * β
    loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y)
    return loglike_y + logprior_β + logprior_σ
end

function LogDensityProblems.dimension(model::LogReg)
    return size(model.X, 2) + 1
end

function LogDensityProblems.capabilities(::Type{<:LogReg})
    return LogDensityProblems.LogDensityOrder{0}()
end;
```
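As a quick sanity check, the interface can be exercised on toy data. The numbers below are made up purely for illustration and are not the sonar dataset used later:

```julia
# Hypothetical toy data, just to confirm the LogDensityProblem interface works.
X_toy = randn(25, 10)
y_toy = rand(Bool, 25)
model_toy = LogReg(X_toy, y_toy)

θ_toy = [randn(10); 1.0]  # 10 coefficients followed by a positive σ
LogDensityProblems.logdensity(model_toy, θ_toy)  # a finite log joint density
LogDensityProblems.dimension(model_toy)          # == 11
```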
Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a bijector to transform `θ`. We will use `Bijectors` for this purpose. The bijector corresponding to the joint support of our model can be constructed as follows:
```julia
using Bijectors: Bijectors

function Bijectors.bijector(model::LogReg)
    d = size(model.X, 2)
    return Bijectors.Stacked(
        Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]),
        [1:d, (d + 1):(d + 1)],
    )
end;
```
A simpler approach would be to use Turing, where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.
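Roughly, and assuming the design matrix `X` and labels `y` defined below, the Turing route could look like the following sketch. This is an unverified illustration: the conversion utilities used at the end (`Turing.DynamicPPL.LogDensityFunction` and `Bijectors.bijector` applied to a model) are assumptions and may differ across Turing/DynamicPPL versions.

```julia
# Hedged sketch of the Turing route; the conversion calls below are assumptions.
using Turing
using FillArrays: Zeros

@model function logreg(X, y)
    d = size(X, 2)
    σ ~ LogNormal(0, 3)
    β ~ MvNormal(Zeros(d), σ)
    return y ~ product_distribution(BernoulliLogit.(X * β))
end

model = logreg(X, y)
prob = Turing.DynamicPPL.LogDensityFunction(model)  # assumed Turing.Model -> LogDensityProblem
b = Bijectors.bijector(model)                        # assumed bijector for the joint support
```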
Since most VI algorithms assume that the posterior is unconstrained, we will apply a change-of-variable to our model to make it unconstrained. This amounts to wrapping it in a `LogDensityProblem` that applies the transformation and the corresponding Jacobian adjustment.
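Concretely, if $b$ denotes the bijector mapping the constrained support to $\mathbb{R}^d$ and $b^{-1}$ its inverse, the transformed log density is

$$
\log \tilde{p}(\theta_{\mathrm{trans}})
= \log p\left(b^{-1}(\theta_{\mathrm{trans}})\right)
+ \log \left| \det J_{b^{-1}}(\theta_{\mathrm{trans}}) \right| ,
$$

which is what the `logdensity` method below computes via `Bijectors.with_logabsdet_jacobian`.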
```julia
struct TransformedLogDensityProblem{Prob,BInv}
    prob::Prob
    binv::BInv
end

function TransformedLogDensityProblem(prob)
    b = Bijectors.bijector(prob)
    binv = Bijectors.inverse(b)
    return TransformedLogDensityProblem{typeof(prob),typeof(binv)}(prob, binv)
end

function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
    (; prob, binv) = prob_trans
    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(binv, θ_trans)
    return LogDensityProblems.logdensity(prob, θ) + logabsdetjac
end

function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
    (; prob, binv) = prob_trans
    b = Bijectors.inverse(binv)
    d = LogDensityProblems.dimension(prob)
    return prod(Bijectors.output_size(b, (d,)))
end

function LogDensityProblems.capabilities(
    ::Type{TransformedLogDensityProblem{Prob,BInv}}
) where {Prob,BInv}
    return LogDensityProblems.capabilities(Prob)
end;
```
For the dataset, we will use the popular sonar classification dataset from the UCI repository, which can be automatically downloaded using OpenML. The sonar dataset corresponds to dataset id 40.
```julia
using OpenML: OpenML
using DataFrames: DataFrames

data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine");
```
Let's apply some basic pre-processing and add an intercept column:
```julia
using Statistics

X = (X .- mean(X; dims=2)) ./ std(X; dims=2)
X = hcat(X, ones(size(X, 1)));
```
The model can now be instantiated as follows:
```julia
prob = LogReg(X, y);
prob_trans = TransformedLogDensityProblem(prob)
```
For the VI algorithm, we will use `KLMinRepGradDescent`:
```julia
using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff(); operator=ClipScale())
```
This algorithm minimizes the exclusive (reverse) KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation, using the reparametrization gradient[^1][^2][^3]. This approach is also commonly referred to as automatic differentiation VI, black-box VI, or stochastic gradient VI.
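Schematically, and glossing over how the entropy term is estimated (this depends on the chosen objective configuration), the algorithm performs stochastic gradient ascent on the evidence lower bound (ELBO) with reparametrized Monte Carlo gradient estimates:

$$
\mathrm{ELBO}(\lambda)
= \mathbb{E}_{z \sim q_\lambda}\left[\log \pi(z)\right] + \mathbb{H}\left[q_\lambda\right]
= \mathbb{E}_{\varepsilon \sim \varphi}\left[\log \pi\left(T_\lambda(\varepsilon)\right)\right] + \mathbb{H}\left[q_\lambda\right],
$$

$$
\nabla_\lambda \mathrm{ELBO}(\lambda)
\approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log \pi\left(T_\lambda(\varepsilon_m)\right)
+ \nabla_\lambda \mathbb{H}\left[q_\lambda\right],
$$

where $\pi$ is the (unnormalized) target posterior, $q_\lambda$ is obtained by pushing a base distribution $\varphi$ through the parameter-dependent map $T_\lambda$, and $\varepsilon_1, \ldots, \varepsilon_M \sim \varphi$.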
Projection or proximal operators can also be supplied through the keyword argument `operator`. For this example, we will use a Gaussian variational family, which is part of the broader location-scale family. These families require the scale matrix to have strictly positive eigenvalues at all times; here, the projection operator `ClipScale` ensures this.
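To illustrate the idea only (this is a conceptual sketch, not AdvancedVI's actual `ClipScale` implementation), such a projection can be pictured as clamping the diagonal of the lower-triangular scale factor away from zero after each update:

```julia
using LinearAlgebra

# Conceptual sketch: project a lower-triangular scale factor L so that the
# covariance L*L' stays positive definite, by clamping the diagonal entries
# of L to be at least ϵ. Illustrative only; not AdvancedVI's internal code.
function clip_scale_sketch(L::LowerTriangular, ϵ::Real=1e-5)
    M = Matrix(L)
    M[diagind(M)] .= max.(diag(M), ϵ)
    return LowerTriangular(M)
end

clip_scale_sketch(LowerTriangular([0.5 0.0; -0.2 1e-8]))  # second diagonal entry is clipped to ϵ
```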
`KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` provides gradients. For this, it is straightforward to use `LogDensityProblemsAD`:
```julia
using DifferentiationInterface: DifferentiationInterface
using LogDensityProblemsAD: LogDensityProblemsAD

prob_trans_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), prob_trans);
```
For the variational family, we will consider a `FullRankGaussian` approximation:
```julia
using LinearAlgebra

d = LogDensityProblems.dimension(prob_trans_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d)))

# Alternatively, a mean-field (diagonal covariance) approximation could be used:
# q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)));
```
We can now run VI:
```julia
max_iter = 10^3
q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans_ad, q);
```
Recall that we applied a change-of-variable to the posterior to make it unconstrained. The approximation `q_opt`, however, targets this transformed posterior rather than the original constrained posterior we set out to approximate. Therefore, we finally need to apply the corresponding change-of-variable to `q_opt` so that it approximates our original problem.
```julia
b = Bijectors.bijector(prob)
binv = Bijectors.inverse(b)
q_trans = Bijectors.TransformedDistribution(q_opt, binv)
```
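Samples drawn from `q_trans` then live in the original constrained space, so they can be used directly, for example to estimate posterior means:

```julia
# Draw approximate posterior samples; the inverse bijector is applied
# automatically, so the sampled σ values are guaranteed to be positive.
samples = rand(q_trans, 1_000)                       # d × 1_000 matrix
β_mean = vec(mean(samples[1:(end - 1), :]; dims=2))  # posterior mean of the coefficients
σ_mean = mean(samples[end, :])                       # posterior mean of σ
```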
For more examples and details, please refer to the documentation.
[^1]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR.
[^2]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR.
[^3]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In *International Conference on Learning Representations*.