Probability distributions - torch.distributions#
Created On: Oct 19, 2017 | Last Updated On: Jun 13, 2025
The distributions package contains parameterizable probability distributions and sampling functions. This allows the construction of stochastic computation graphs and stochastic gradient estimators for optimization. This package generally follows the design of the TensorFlow Distributions package.
It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through. These are the score function estimator/likelihood ratio estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whilst the score function only requires the value of samples f(x), the pathwise derivative requires the derivative f'(x). The next sections discuss these two in a reinforcement learning example. For more details see Gradient Estimation Using Stochastic Computation Graphs.
Score function#
When the probability density function is differentiable with respect to its parameters, we only need sample() and log_prob() to implement REINFORCE:
Δθ = α r ∂ log p(a | π^θ(s)) / ∂θ

where θ are the parameters, α is the learning rate, r is the reward, and p(a | π^θ(s)) is the probability of taking action a in state s given policy π^θ.
In practice we would sample an action from the output of a network, apply this action in an environment, and then use log_prob to construct an equivalent loss function. Note that we use a negative because optimizers use gradient descent, whilst the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE would be as follows:
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
Pathwise derivative#
The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable. The code for implementing the pathwise derivative would be as follows:
params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)
# Assuming that reward is differentiable
loss = -reward
loss.backward()
Distribution#
- classtorch.distributions.distribution.Distribution(batch_shape=(),event_shape=(),validate_args=None)[source]#
Bases: object
Distribution is the abstract base class for probability distributions.
- Parameters:
batch_shape (torch.Size) – The shape over which parameters are batched.
event_shape (torch.Size) – The shape of a single sample (without batching).
validate_args (bool,optional) – Whether to validate arguments. Default: None.
- propertyarg_constraints:dict[str,Constraint]#
Returns a dictionary from argument names to Constraint objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict.
- entropy()[source]#
Returns entropy of distribution, batched over batch_shape.
- Returns:
Tensor of shape batch_shape.
- Return type:
- enumerate_support(expand=True)[source]#
Returns a tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).
Note that this enumerates over all batched tensors in lock-step [[0, 0], [1, 1], ...]. With expand=False, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, [[0], [1], ...].
To iterate over the full Cartesian product use itertools.product(m.enumerate_support()).
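A brief illustrative sketch (not part of the original reference) of enumerate_support() on a batched Bernoulli; the commented values are what the shapes imply for a batch_shape of (2,):

import torch
from torch.distributions import Bernoulli

m = Bernoulli(torch.tensor([0.3, 0.7]))   # batch_shape (2,), cardinality 2
m.enumerate_support()
# tensor([[0., 0.],
#         [1., 1.]])    shape: (cardinality,) + batch_shape = (2, 2), lock-step
m.enumerate_support(expand=False)
# tensor([[0.],
#         [1.]])        remaining batch dimensions stay singleton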
- expand(batch_shape,_instance=None)[source]#
Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape. This method calls expand on the distribution's parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in __init__.py, when an instance is first created.
- Parameters:
batch_shape (torch.Size) – the desired expanded size.
_instance – new instance provided by subclasses that need to override .expand.
- Returns:
New distribution instance with batch dimensions expanded to batch_shape.
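A brief illustrative sketch (not part of the original reference) of expand(), assuming a scalar-batch Normal:

import torch
from torch.distributions import Normal

d = Normal(torch.tensor(0.0), torch.tensor(1.0))
d.batch_shape                      # torch.Size([])
e = d.expand(torch.Size([3, 2]))   # no new memory allocated for the parameters
e.batch_shape                      # torch.Size([3, 2])
e.sample().shape                   # torch.Size([3, 2])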
- perplexity()[source]#
Returns perplexity of distribution, batched over batch_shape.
- Returns:
Tensor of shape batch_shape.
- Return type:
- rsample(sample_shape=())[source]#
Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
- Return type:
- sample(sample_shape=())[source]#
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
- Return type:
- sample_n(n)[source]#
Generates n samples or n batches of samples if the distribution parameters are batched.
- Return type:
- staticset_default_validate_args(value)[source]#
Sets whether validation is enabled or disabled.
The default behavior mimics Python's assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Validation may be expensive, so you may want to disable it once a model is working.
- Parameters:
value (bool) – Whether to enable validation.
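A minimal usage sketch (not part of the original reference):

from torch.distributions import Distribution

# Disable argument validation globally once a model is known to be working.
Distribution.set_default_validate_args(False)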
- propertysupport:Constraint|None#
Returns a Constraint object representing this distribution's support.
ExponentialFamily#
- classtorch.distributions.exp_family.ExponentialFamily(batch_shape=(),event_shape=(),validate_args=None)[source]#
Bases: Distribution
ExponentialFamily is the abstract base class for probability distributions belonging to an exponential family, whose probability mass/density function has the form defined below

p_F(x; θ) = exp(⟨t(x), θ⟩ − F(θ) + k(x))

where θ denotes the natural parameters, t(x) denotes the sufficient statistic, F(θ) is the log normalizer function for a given family and k(x) is the carrier measure.
Note
This class is an intermediary between the Distribution class and distributions which belong to an exponential family mainly to check the correctness of the .entropy() and analytic KL divergence methods. We use this class to compute the entropy and KL divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and Cross-entropies of Exponential Families).
Bernoulli#
- classtorch.distributions.bernoulli.Bernoulli(probs=None,logits=None,validate_args=None)[source]#
Bases: ExponentialFamily
Creates a Bernoulli distribution parameterized by probs or logits (but not both).
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Example:
>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([ 0.])
- Parameters:
- arg_constraints={'logits':Real(),'probs':Interval(lower_bound=0.0,upper_bound=1.0)}#
- has_enumerate_support=True#
- support=Boolean()#
Beta#
- classtorch.distributions.beta.Beta(concentration1,concentration0,validate_args=None)[source]#
Bases: ExponentialFamily
Beta distribution parameterized by concentration1 and concentration0.
Example:
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([ 0.1046])
- Parameters:
- arg_constraints={'concentration0':GreaterThan(lower_bound=0.0),'concentration1':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=Interval(lower_bound=0.0,upper_bound=1.0)#
Binomial#
- classtorch.distributions.binomial.Binomial(total_count=1,probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both). total_count must be broadcastable with probs/logits.
Example:
>>> m = Binomial(100, torch.tensor([0, .2, .8, 1]))
>>> x = m.sample()
tensor([   0.,  22.,  71., 100.])
>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
>>> x = m.sample()
tensor([[ 4.,  5.],
        [ 7.,  6.]])
- Parameters:
- arg_constraints={'logits':Real(),'probs':Interval(lower_bound=0.0,upper_bound=1.0),'total_count':IntegerGreaterThan(lower_bound=0)}#
- has_enumerate_support=True#
- propertysupport#
- Return type:
_DependentProperty
Categorical#
- classtorch.distributions.categorical.Categorical(probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a categorical distribution parameterized by either probs or logits (but not both).
Note
It is equivalent to the distribution that torch.multinomial() samples from.
Samples are integers from {0, ..., K - 1} where K is probs.size(-1).
If probs is 1-dimensional with length K, each element is the relative probability of sampling the class at that index.
If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
See also: torch.multinomial()
Example:
>>> m = Categorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor(3)
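A brief sketch (not part of the original reference) of the normalization behavior described in the note above; the commented values follow from renormalizing the relative probabilities:

import torch
from torch.distributions import Categorical

m = Categorical(probs=torch.tensor([1.0, 1.0, 2.0]))  # unnormalized relative probabilities
m.probs                        # tensor([0.2500, 0.2500, 0.5000])
m.log_prob(torch.tensor(2))    # log(0.5)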
- Parameters:
- arg_constraints={'logits':IndependentConstraint(Real(),1),'probs':Simplex()}#
- has_enumerate_support=True#
- propertysupport#
- Return type:
_DependentProperty
Cauchy#
- classtorch.distributions.cauchy.Cauchy(loc,scale,validate_args=None)[source]#
Bases: Distribution
Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.
Example:
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
tensor([ 2.3214])
- Parameters:
- arg_constraints={'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=Real()#
Chi2#
- classtorch.distributions.chi2.Chi2(df,validate_args=None)[source]#
Bases: Gamma
Creates a Chi-squared distribution parameterized by shape parameter df. This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5).
Example:
>>> m = Chi2(torch.tensor([1.0]))
>>> m.sample()  # Chi2 distributed with shape df=1
tensor([ 0.1046])
- arg_constraints={'df':GreaterThan(lower_bound=0.0)}#
ContinuousBernoulli#
- classtorch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=None,logits=None,lims=(0.499,0.501),validate_args=None)[source]#
Bases: ExponentialFamily
Creates a continuous Bernoulli distribution parameterized by probs or logits (but not both).
The distribution is supported in [0, 1] and parameterized by 'probs' (in (0, 1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs' does not correspond to a probability and 'logits' does not correspond to log-odds, but the same names are used due to the similarity with the Bernoulli. See [1] for more details.
Example:
>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([ 0.2538])
- Parameters:
[1] The continuous Bernoulli: fixing a pervasive error in variationalautoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.https://arxiv.org/abs/1907.06845
- arg_constraints={'logits':Real(),'probs':Interval(lower_bound=0.0,upper_bound=1.0)}#
- has_rsample=True#
- support=Interval(lower_bound=0.0,upper_bound=1.0)#
Dirichlet#
- classtorch.distributions.dirichlet.Dirichlet(concentration,validate_args=None)[source]#
Bases: ExponentialFamily
Creates a Dirichlet distribution parameterized by concentration concentration.
Example:
>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
tensor([ 0.1046,  0.8954])
- Parameters:
concentration (Tensor) – concentration parameter of the distribution (often referred to as alpha)
- arg_constraints={'concentration':IndependentConstraint(GreaterThan(lower_bound=0.0),1)}#
- has_rsample=True#
- support=Simplex()#
Exponential#
- classtorch.distributions.exponential.Exponential(rate,validate_args=None)[source]#
Bases: ExponentialFamily
Creates an Exponential distribution parameterized by rate.
Example:
>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([ 0.1046])
- arg_constraints={'rate':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThanEq(lower_bound=0.0)#
FisherSnedecor#
- classtorch.distributions.fishersnedecor.FisherSnedecor(df1,df2,validate_args=None)[source]#
Bases: Distribution
Creates a Fisher-Snedecor distribution parameterized by df1 and df2.
Example:
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([ 0.2453])
- Parameters:
- arg_constraints={'df1':GreaterThan(lower_bound=0.0),'df2':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThan(lower_bound=0.0)#
Gamma#
- classtorch.distributions.gamma.Gamma(concentration,rate,validate_args=None)[source]#
Bases: ExponentialFamily
Creates a Gamma distribution parameterized by shape concentration and rate.
Example:
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([ 0.1046])
- Parameters:
- arg_constraints={'concentration':GreaterThan(lower_bound=0.0),'rate':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThanEq(lower_bound=0.0)#
GeneralizedPareto#
- classtorch.distributions.generalized_pareto.GeneralizedPareto(loc,scale,concentration,validate_args=None)[source]#
Bases: Distribution
Creates a Generalized Pareto distribution parameterized by loc, scale, and concentration.
The Generalized Pareto distribution is a family of continuous probability distributions on the real line. Special cases include Exponential (when loc = 0, concentration = 0), Pareto (when concentration > 0, loc = scale / concentration), and Uniform (when concentration = -1).
This distribution is often used to model the tails of other distributions. This implementation is based on the implementation in TensorFlow Probability.
Example:
>>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4]))
>>> m.sample()  # sample from a Generalized Pareto distribution with loc=0.1, scale=2.0, and concentration=0.4
tensor([ 1.5623])
- Parameters:
- arg_constraints={'concentration':Real(),'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- propertymean#
- propertymode#
- propertysupport#
- Return type:
_DependentProperty
- propertyvariance#
Geometric#
- classtorch.distributions.geometric.Geometric(probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a Geometric distribution parameterized by probs, where probs is the probability of success of Bernoulli trials.
Note
For torch.distributions.geometric.Geometric() the (k+1)-th trial is the first success, hence it draws samples in {0, 1, ...}, whereas for torch.Tensor.geometric_() the k-th trial is the first success, hence it draws samples in {1, 2, ...}.
Example:
>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([ 2.])
- Parameters:
- arg_constraints={'logits':Real(),'probs':Interval(lower_bound=0.0,upper_bound=1.0)}#
- support=IntegerGreaterThan(lower_bound=0)#
Gumbel#
- classtorch.distributions.gumbel.Gumbel(loc,scale,validate_args=None)[source]#
Bases: TransformedDistribution
Samples from a Gumbel Distribution.
Examples:
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([ 1.0124])
- Parameters:
- arg_constraints:dict[str,Constraint]={'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- support=Real()#
HalfCauchy#
- classtorch.distributions.half_cauchy.HalfCauchy(scale,validate_args=None)[source]#
Bases: TransformedDistribution
Creates a half-Cauchy distribution parameterized by scale where:
X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)
Example:
>>> m = HalfCauchy(torch.tensor([1.0]))
>>> m.sample()  # half-cauchy distributed with scale=1
tensor([ 2.3214])
- arg_constraints:dict[str,Constraint]={'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThanEq(lower_bound=0.0)#
HalfNormal#
- classtorch.distributions.half_normal.HalfNormal(scale,validate_args=None)[source]#
Bases: TransformedDistribution
Creates a half-normal distribution parameterized by scale where:
X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)
Example:
>>> m = HalfNormal(torch.tensor([1.0]))
>>> m.sample()  # half-normal distributed with scale=1
tensor([ 0.1046])
- arg_constraints:dict[str,Constraint]={'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThanEq(lower_bound=0.0)#
Independent#
- classtorch.distributions.independent.Independent(base_distribution,reinterpreted_batch_ndims,validate_args=None)[source]#
Bases: Distribution, Generic[D]
Reinterprets some of the batch dims of a distribution as event dims.
This is mainly useful for changing the shape of the result of log_prob(). For example to create a diagonal Normal distribution with the same shape as a Multivariate Normal distribution (so they are interchangeable), you can:

>>> from torch.distributions.multivariate_normal import MultivariateNormal
>>> from torch.distributions.normal import Normal
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([]), torch.Size([3])]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([3]), torch.Size([])]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size([]), torch.Size([3])]
- Parameters:
base_distribution (torch.distributions.distribution.Distribution) – a base distribution
reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims
- arg_constraints:dict[str,Constraint]={}#
- base_dist:D#
- propertysupport#
- Return type:
_DependentProperty
InverseGamma#
- classtorch.distributions.inverse_gamma.InverseGamma(concentration,rate,validate_args=None)[source]#
Bases: TransformedDistribution
Creates an inverse gamma distribution parameterized by concentration and rate where:
X ~ Gamma(concentration, rate)
Y = 1 / X ~ InverseGamma(concentration, rate)
Example:
>>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
>>> m.sample()
tensor([ 1.2953])
- Parameters:
- arg_constraints:dict[str,Constraint]={'concentration':GreaterThan(lower_bound=0.0),'rate':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThan(lower_bound=0.0)#
Kumaraswamy#
- classtorch.distributions.kumaraswamy.Kumaraswamy(concentration1,concentration0,validate_args=None)[source]#
Bases: TransformedDistribution
Samples from a Kumaraswamy distribution.
Example:
>>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
tensor([ 0.1729])
- Parameters:
- arg_constraints:dict[str,Constraint]={'concentration0':GreaterThan(lower_bound=0.0),'concentration1':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=Interval(lower_bound=0.0,upper_bound=1.0)#
LKJCholesky#
- classtorch.distributions.lkj_cholesky.LKJCholesky(dim,concentration=1.0,validate_args=None)[source]#
Bases: Distribution
LKJ distribution for lower Cholesky factor of correlation matrices.
The distribution is controlled by the concentration parameter η to make the probability of the correlation matrix M generated from a Cholesky factor proportional to det(M)^(η − 1). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices:
L ~ LKJCholesky(dim, concentration)
X = L @ L' ~ LKJCorr(dim, concentration)
Note that this distribution samples the Cholesky factor of correlation matrices and not the correlation matrices themselves and thereby differs slightly from the derivations in [1] for the LKJCorr distribution. For sampling, this uses the Onion method from [1] Section 3.
Example:
>>> l = LKJCholesky(3, 0.5)
>>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
tensor([[ 1.0000,  0.0000,  0.0000],
        [ 0.3516,  0.9361,  0.0000],
        [-0.1899,  0.4748,  0.8593]])
- Parameters:
References
[1] Generating random correlation matrices based on vines and extended onion method (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe. Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
- arg_constraints={'concentration':GreaterThan(lower_bound=0.0)}#
- support=CorrCholesky()#
Laplace#
- classtorch.distributions.laplace.Laplace(loc,scale,validate_args=None)[source]#
Bases: Distribution
Creates a Laplace distribution parameterized by loc and scale.
Example:
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # Laplace distributed with loc=0, scale=1
tensor([ 0.1046])
- Parameters:
- arg_constraints={'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=Real()#
LogNormal#
- classtorch.distributions.log_normal.LogNormal(loc,scale,validate_args=None)[source]#
Bases: TransformedDistribution
Creates a log-normal distribution parameterized by loc and scale where:
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
Example:
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with mean=0 and stddev=1
tensor([ 0.1046])
- Parameters:
- arg_constraints:dict[str,Constraint]={'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=GreaterThan(lower_bound=0.0)#
LowRankMultivariateNormal#
- classtorch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc,cov_factor,cov_diag,validate_args=None)[source]#
Bases: Distribution
Creates a multivariate normal distribution with covariance matrix having a low-rank form parameterized by cov_factor and cov_diag:
covariance_matrix = cov_factor @ cov_factor.T + cov_diag
Example
>>> m = LowRankMultivariateNormal(
...     torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
... )
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])
- Parameters:
Note
The computation for determinant and inverse of covariance matrix is avoided when cov_factor.shape[1] << cov_factor.shape[0] thanks to the Woodbury matrix identity and the matrix determinant lemma. Thanks to these formulas, we just need to compute the determinant and inverse of the small size "capacitance" matrix:
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
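A brief sketch (not part of the original reference) checking the low-rank form of the covariance; the cov_factor and cov_diag values here are arbitrary illustrative choices:

import torch
from torch.distributions import LowRankMultivariateNormal

W = torch.tensor([[1.0], [0.0]])    # cov_factor, shape (2, 1)
d = torch.ones(2)                   # cov_diag
m = LowRankMultivariateNormal(torch.zeros(2), W, d)
torch.allclose(m.covariance_matrix, W @ W.T + torch.diag(d))   # True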
- arg_constraints={'cov_diag':IndependentConstraint(GreaterThan(lower_bound=0.0),1),'cov_factor':IndependentConstraint(Real(),2),'loc':IndependentConstraint(Real(),1)}#
- has_rsample=True#
- support=IndependentConstraint(Real(),1)#
MixtureSameFamily#
- classtorch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution,component_distribution,validate_args=None)[source]#
Bases: Distribution
The MixtureSameFamily distribution implements a (batch of) mixture distribution where all components are from different parameterizations of the same distribution type. It is parameterized by a Categorical "selecting distribution" (over k components) and a component distribution, i.e., a Distribution with a rightmost batch shape (equal to [k]) which indexes each (batch of) component.
Examples:
>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
>>> # weighted normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
>>> gmm = MixtureSameFamily(mix, comp)

>>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
>>> # weighted bivariate normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Independent(D.Normal(
...     torch.randn(5, 2), torch.rand(5, 2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)

>>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
>>> # consisting of 5 random weighted bivariate normal distributions
>>> mix = D.Categorical(torch.rand(3, 5))
>>> comp = D.Independent(D.Normal(
...     torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)
- Parameters:
mixture_distribution (Categorical) – torch.distributions.Categorical-like instance. Manages the probability of selecting component. The number of categories must match the rightmost batch dimension of the component_distribution. Must have either scalar batch_shape or batch_shape matching component_distribution.batch_shape[:-1]
component_distribution (Distribution) – torch.distributions.Distribution-like instance. Right-most batch dimension indexes component.
- arg_constraints:dict[str,Constraint]={}#
- propertycomponent_distribution:Distribution#
- has_rsample=False#
- propertymixture_distribution:Categorical#
- propertysupport#
- Return type:
_DependentProperty
Multinomial#
- classtorch.distributions.multinomial.Multinomial(total_count=1,probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.
Note that total_count need not be specified if only log_prob() is called (see example below).
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
sample() requires a single shared total_count for all parameters and samples.
log_prob() allows different total_count for each parameter and sample.
Example:
>>> m = Multinomial(100, torch.tensor([1., 1., 1., 1.]))
>>> x = m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 21.,  24.,  30.,  25.])
>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])
- Parameters:
- arg_constraints={'logits':IndependentConstraint(Real(),1),'probs':Simplex()}#
- propertysupport#
- Return type:
_DependentProperty
MultivariateNormal#
- classtorch.distributions.multivariate_normal.MultivariateNormal(loc,covariance_matrix=None,precision_matrix=None,scale_tril=None,validate_args=None)[source]#
Bases: Distribution
Creates a multivariate normal (also called Gaussian) distribution parameterized by a mean vector and a covariance matrix.
The multivariate normal distribution can be parameterized either in terms of a positive definite covariance matrix Σ, or a positive definite precision matrix Σ⁻¹, or a lower-triangular matrix L with positive-valued diagonal entries, such that Σ = LLᵀ. This triangular matrix can be obtained via e.g. Cholesky decomposition of the covariance.
Example
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])
- Parameters:
Note
Only one of covariance_matrix or precision_matrix or scale_tril can be specified.
Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition (see the sketch below).
- arg_constraints={'covariance_matrix':PositiveDefinite(),'loc':IndependentConstraint(Real(),1),'precision_matrix':PositiveDefinite(),'scale_tril':LowerCholesky()}#
- has_rsample=True#
- support=IndependentConstraint(Real(),1)#
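A brief sketch (not part of the original reference) of the scale_tril parameterization mentioned in the note above; the covariance values are arbitrary:

import torch
from torch.distributions import MultivariateNormal

cov = torch.tensor([[2.0, 0.5],
                    [0.5, 1.0]])
L = torch.linalg.cholesky(cov)            # lower-triangular with positive diagonal
m = MultivariateNormal(torch.zeros(2), scale_tril=L)
x = m.rsample()                           # reparameterized sample
m.log_prob(x)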
NegativeBinomial#
- classtorch.distributions.negative_binomial.NegativeBinomial(total_count,probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a Negative Binomial distribution, i.e. distribution of the number of successful independent and identical Bernoulli trials before total_count failures are achieved. The probability of success of each Bernoulli trial is probs.
- Parameters:
- arg_constraints={'logits':Real(),'probs':HalfOpenInterval(lower_bound=0.0,upper_bound=1.0),'total_count':GreaterThanEq(lower_bound=0)}#
- support=IntegerGreaterThan(lower_bound=0)#
Normal#
- classtorch.distributions.normal.Normal(loc,scale,validate_args=None)[source]#
Bases: ExponentialFamily
Creates a normal (also called Gaussian) distribution parameterized by loc and scale.
Example:
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([ 0.1046])
- Parameters:
- arg_constraints={'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=Real()#
OneHotCategorical#
- classtorch.distributions.one_hot_categorical.OneHotCategorical(probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a one-hot categorical distribution parameterized by probs or logits.
Samples are one-hot coded vectors of size probs.size(-1).
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
See also: torch.distributions.Categorical() for specifications of probs and logits.
Example:
>>> m = OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 0.,  0.,  0.,  1.])
- Parameters:
- arg_constraints={'logits':IndependentConstraint(Real(),1),'probs':Simplex()}#
- has_enumerate_support=True#
- support=OneHot()#
Pareto#
- classtorch.distributions.pareto.Pareto(scale,alpha,validate_args=None)[source]#
Bases: TransformedDistribution
Samples from a Pareto Type 1 distribution.
Example:
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
tensor([ 1.5623])
- Parameters:
- arg_constraints:dict[str,Constraint]={'alpha':GreaterThan(lower_bound=0.0),'scale':GreaterThan(lower_bound=0.0)}#
- propertysupport:Constraint#
- Return type:
_DependentProperty
Poisson#
- classtorch.distributions.poisson.Poisson(rate,validate_args=None)[source]#
Bases: ExponentialFamily
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are nonnegative integers, with a pmf given by

rate^k * e^(-rate) / k!
Example:
>>> m = Poisson(torch.tensor([4]))
>>> m.sample()
tensor([ 3.])
- Parameters:
rate (Number,Tensor) – the rate parameter
- arg_constraints={'rate':GreaterThanEq(lower_bound=0.0)}#
- support=IntegerGreaterThan(lower_bound=0)#
RelaxedBernoulli#
- classtorch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature,probs=None,logits=None,validate_args=None)[source]#
Bases: TransformedDistribution
Creates a RelaxedBernoulli distribution, parametrized by temperature, and either probs or logits (but not both). This is a relaxed version of the Bernoulli distribution, so the values are in (0, 1), and has reparametrizable samples.
Example:
>>> m = RelaxedBernoulli(torch.tensor([2.2]),
...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([ 0.2951,  0.3442,  0.8918,  0.9021])
- Parameters:
- arg_constraints:dict[str,Constraint]={'logits':Real(),'probs':Interval(lower_bound=0.0,upper_bound=1.0)}#
- base_dist:LogitRelaxedBernoulli#
- has_rsample=True#
- support=Interval(lower_bound=0.0,upper_bound=1.0)#
LogitRelaxedBernoulli#
- classtorch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli(temperature,probs=None,logits=None,validate_args=None)[source]#
Bases: Distribution
Creates a LogitRelaxedBernoulli distribution parameterized by probs or logits (but not both), which is the logit of a RelaxedBernoulli distribution.
Samples are logits of values in (0, 1). See [1] for more details.
- Parameters:
[1] The Concrete Distribution: A Continuous Relaxation of Discrete RandomVariables (Maddison et al., 2017)
[2] Categorical Reparametrization with Gumbel-Softmax (Jang et al., 2017)
- arg_constraints={'logits':Real(),'probs':Interval(lower_bound=0.0,upper_bound=1.0)}#
- support=Real()#
RelaxedOneHotCategorical#
- classtorch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature,probs=None,logits=None,validate_args=None)[source]#
Bases: TransformedDistribution
Creates a RelaxedOneHotCategorical distribution parametrized by temperature, and either probs or logits.
This is a relaxed version of the OneHotCategorical distribution, so its samples are on the simplex, and are reparametrizable.
Example:
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([ 0.1294,  0.2324,  0.3859,  0.2523])
- Parameters:
- arg_constraints:dict[str,Constraint]={'logits':IndependentConstraint(Real(),1),'probs':Simplex()}#
- base_dist:ExpRelaxedCategorical#
- has_rsample=True#
- support=Simplex()#
StudentT#
- classtorch.distributions.studentT.StudentT(df,loc=0.0,scale=1.0,validate_args=None)[source]#
Bases: Distribution
Creates a Student's t-distribution parameterized by degrees of freedom df, mean loc and scale scale.
Example:
>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([ 0.1046])
- Parameters:
- arg_constraints={'df':GreaterThan(lower_bound=0.0),'loc':Real(),'scale':GreaterThan(lower_bound=0.0)}#
- has_rsample=True#
- support=Real()#
TransformedDistribution#
- classtorch.distributions.transformed_distribution.TransformedDistribution(base_distribution,transforms,validate_args=None)[source]#
Bases: Distribution
Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:

X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|

Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.
An example for the usage of TransformedDistribution would be:

# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
For more examples, please look at the implementations of
Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical.
- arg_constraints:dict[str,Constraint]={}#
- cdf(value)[source]#
Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution.
- icdf(value)[source]#
Computes the inverse cumulative distribution function using transform(s) and computing the score of the base distribution.
- log_prob(value)[source]#
Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian.
- rsample(sample_shape=())[source]#
Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from the base distribution and applies transform() for every transform in the list.
- Return type:
- sample(sample_shape=())[source]#
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from the base distribution and applies transform() for every transform in the list.
- propertysupport#
- Return type:
_DependentProperty
Uniform#
- classtorch.distributions.uniform.Uniform(low,high,validate_args=None)[source]#
Bases: Distribution
Generates uniformly distributed random samples from the half-open interval [low, high).
Example:
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([ 2.3418])
- Parameters:
- propertyarg_constraints#
- has_rsample=True#
- propertysupport#
- Return type:
_DependentProperty
VonMises#
- classtorch.distributions.von_mises.VonMises(loc,concentration,validate_args=None)[source]#
Bases: Distribution
A circular von Mises distribution.
This implementation uses polar coordinates. The loc and value args can be any real number (to facilitate unconstrained optimization), but are interpreted as angles modulo 2 pi.
- Example::
>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # von Mises distributed with loc=1 and concentration=1
tensor([1.9777])
- Parameters:
loc (torch.Tensor) – an angle in radians.
concentration (torch.Tensor) – concentration parameter
- arg_constraints={'concentration':GreaterThan(lower_bound=0.0),'loc':Real()}#
- has_rsample=False#
- sample(sample_shape=())[source]#
The sampling algorithm for the von Mises distribution is based on thefollowing paper: D.J. Best and N.I. Fisher, “Efficient simulation of thevon Mises distribution.” Applied Statistics (1979): 152-157.
Sampling is always done in double precision internally to avoid a hang in _rejection_sample() for small values of the concentration, which starts to happen for single precision around 1e-4 (see issue #88443).
- support=Real()#
Weibull#
- classtorch.distributions.weibull.Weibull(scale,concentration,validate_args=None)[source]#
Bases: TransformedDistribution
Samples from a two-parameter Weibull distribution.
Example
>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([ 0.4784])
- Parameters:
- arg_constraints:dict[str,Constraint]={'concentration':GreaterThan(lower_bound=0.0),'scale':GreaterThan(lower_bound=0.0)}#
- support=GreaterThan(lower_bound=0.0)#
Wishart#
- classtorch.distributions.wishart.Wishart(df,covariance_matrix=None,precision_matrix=None,scale_tril=None,validate_args=None)[source]#
Bases: ExponentialFamily
Creates a Wishart distribution parameterized by a symmetric positive definite matrix Σ, or its Cholesky decomposition Σ = LLᵀ.
Example
>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
>>> m.sample()  # Wishart distributed with mean=`df * I` and
>>>             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
- Parameters:
df (float or Tensor) – real-valued parameter larger than the (dimension of Square matrix) - 1
covariance_matrix (Tensor) – positive-definite covariance matrix
precision_matrix (Tensor) – positive-definite precision matrix
scale_tril (Tensor) – lower-triangular factor of covariance, with positive-valued diagonal
Note
Only one of
covariance_matrix or precision_matrix or scale_tril can be specified.
Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition.
'torch.distributions.LKJCholesky' is a restricted Wishart distribution. [1]
References
[1] Wang, Z., Wu, Y. and Chu, H., 2018. On equivalence of the LKJ distribution and the restricted Wishart distribution.
[2] Sawyer, S., 2007. Wishart Distributions and Inverse-Wishart Sampling.
[3] Anderson, T. W., 2003. An Introduction to Multivariate Statistical Analysis (3rd ed.).
[4] Odell, P. L. & Feiveson, A. H., 1966. A Numerical Procedure to Generate a Sample Covariance Matrix. JASA, 61(313):199-203.
[5] Ku, Y.-C. & Bloomfield, P., 2010. Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX.
- propertyarg_constraints#
- has_rsample=True#
- rsample(sample_shape=(),max_try_correction=None)[source]#
Warning
In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples. Several tries to correct singular samples are performed by default, but it may still end up returning singular matrix samples. Singular samples may return -inf values in .log_prob(). In those cases, the user should validate the samples and either fix the value of df or adjust the max_try_correction argument of .rsample() accordingly.
- Return type:
- support=PositiveDefinite()#
KL Divergence#
- torch.distributions.kl.kl_divergence(p,q)[source]#
Compute Kullback-Leibler divergence between two distributions.
- Parameters:
p (Distribution) – A Distribution object.
q (Distribution) – A Distribution object.
- Returns:
A batch of KL divergences of shape batch_shape.
- Return type:
- Raises:
NotImplementedError – If the distribution types have not been registered via
register_kl().
- KL divergence is currently implemented for the following distribution pairs:
Bernoulli and Bernoulli, Bernoulli and Poisson, Beta and Beta, Beta and ContinuousBernoulli, Beta and Exponential, Beta and Gamma, Beta and Normal, Beta and Pareto, Beta and Uniform, Binomial and Binomial, Categorical and Categorical, Cauchy and Cauchy, ContinuousBernoulli and ContinuousBernoulli, ContinuousBernoulli and Exponential, ContinuousBernoulli and Normal, ContinuousBernoulli and Pareto, ContinuousBernoulli and Uniform, Dirichlet and Dirichlet, Exponential and Beta, Exponential and ContinuousBernoulli, Exponential and Exponential, Exponential and Gamma, Exponential and Gumbel, Exponential and Normal, Exponential and Pareto, Exponential and Uniform, ExponentialFamily and ExponentialFamily, Gamma and Beta, Gamma and ContinuousBernoulli, Gamma and Exponential, Gamma and Gamma, Gamma and Gumbel, Gamma and Normal, Gamma and Pareto, Gamma and Uniform, Geometric and Geometric, Gumbel and Beta, Gumbel and ContinuousBernoulli, Gumbel and Exponential, Gumbel and Gamma, Gumbel and Gumbel, Gumbel and Normal, Gumbel and Pareto, Gumbel and Uniform, HalfNormal and HalfNormal, Independent and Independent, Laplace and Beta, Laplace and ContinuousBernoulli, Laplace and Exponential, Laplace and Gamma, Laplace and Laplace, Laplace and Normal, Laplace and Pareto, Laplace and Uniform, LowRankMultivariateNormal and LowRankMultivariateNormal, LowRankMultivariateNormal and MultivariateNormal, MultivariateNormal and LowRankMultivariateNormal, MultivariateNormal and MultivariateNormal, Normal and Beta, Normal and ContinuousBernoulli, Normal and Exponential, Normal and Gamma, Normal and Gumbel, Normal and Laplace, Normal and Normal, Normal and Pareto, Normal and Uniform, OneHotCategorical and OneHotCategorical, Pareto and Beta, Pareto and ContinuousBernoulli, Pareto and Exponential, Pareto and Gamma, Pareto and Normal, Pareto and Pareto, Pareto and Uniform, Poisson and Bernoulli, Poisson and Binomial, Poisson and Poisson, TransformedDistribution and TransformedDistribution, Uniform and Beta, Uniform and ContinuousBernoulli, Uniform and Exponential, Uniform and Gamma, Uniform and Gumbel, Uniform and Normal, Uniform and Pareto, Uniform and Uniform
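A minimal usage sketch (not part of the original reference) for one of the registered pairs, Normal and Normal; the result is batched over batch_shape:

import torch
from torch.distributions import Normal, kl_divergence

p = Normal(torch.zeros(3), torch.ones(3))
q = Normal(torch.ones(3), 2 * torch.ones(3))
kl_divergence(p, q)    # tensor of shape (3,)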
- torch.distributions.kl.register_kl(type_p,type_q)[source]#
Decorator to register a pairwise function with kl_divergence(). Usage:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here
Lookup returns the most specific (type, type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example to resolve the ambiguous situation:
@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...

@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...
you should register a third most-specific implementation, e.g.:
register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.
Transforms#
- classtorch.distributions.transforms.AffineTransform(loc,scale,event_dim=0,cache_size=0)[source]#
Transform via the pointwise affine mapping y = loc + scale * x.
- classtorch.distributions.transforms.CatTransform(tseq,dim=0,lengths=None,cache_size=0)[source]#
Transform functor that applies a sequence of transforms tseq component-wise to each submatrix at dim, of length lengths[dim], in a way compatible with torch.cat().
Example:
x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
x = torch.cat([x0, x0], dim=0)
t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
y = t(x)
- classtorch.distributions.transforms.ComposeTransform(parts,cache_size=0)[source]#
Composes multiple transforms in a chain. The transforms being composed are responsible for caching.
- classtorch.distributions.transforms.CorrCholeskyTransform(cache_size=0)[source]#
Transforms an unconstrained real vector x with length D*(D-1)/2 into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
First we convert x into a lower triangular matrix in row order.
For each row X_i of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform X_i into a unit Euclidean length vector using the following steps:
- Scales into the interval (-1, 1) domain: r_i = tanh(X_i).
- Transforms into an unsigned domain: z_i = r_i^2.
- Applies s_i = StickBreakingTransform(z_i).
- Transforms back into signed domain: y_i = sign(r_i) * sqrt(s_i).
- classtorch.distributions.transforms.CumulativeDistributionTransform(distribution,cache_size=0)[source]#
Transform via the cumulative distribution function of a probability distribution.
- Parameters:
distribution (Distribution) – Distribution whose cumulative distribution function to use for the transformation.
Example:
# Construct a Gaussian copula from a multivariate normal.
base_dist = MultivariateNormal(
    loc=torch.zeros(2),
    scale_tril=LKJCholesky(2).sample(),
)
transform = CumulativeDistributionTransform(Normal(0, 1))
copula = TransformedDistribution(base_dist, [transform])
- classtorch.distributions.transforms.IndependentTransform(base_transform,reinterpreted_batch_ndims,cache_size=0)[source]#
Wrapper around another transform to treat reinterpreted_batch_ndims-many extra of the rightmost dimensions as dependent. This has no effect on the forward or backward transforms, but does sum out reinterpreted_batch_ndims-many of the rightmost dimensions in log_abs_det_jacobian().
- classtorch.distributions.transforms.LowerCholeskyTransform(cache_size=0)[source]#
Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.
This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.
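A brief sketch (not part of the original reference), assuming an arbitrary unconstrained 3x3 input matrix:

import torch
from torch.distributions.transforms import LowerCholeskyTransform

t = LowerCholeskyTransform()
A = torch.randn(3, 3)              # unconstrained
L = t(A)
torch.equal(torch.tril(L), L)      # lower-triangular
bool((L.diagonal() >= 0).all())    # nonnegative diagonal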
- classtorch.distributions.transforms.PositiveDefiniteTransform(cache_size=0)[source]#
Transform from unconstrained matrices to positive-definite matrices.
- classtorch.distributions.transforms.PowerTransform(exponent,cache_size=0)[source]#
Transform via the mapping y = x^exponent.
- classtorch.distributions.transforms.ReshapeTransform(in_shape,out_shape,cache_size=0)[source]#
Unit Jacobian transform to reshape the rightmost part of a tensor.
Note that in_shape and out_shape must have the same number of elements, just as for torch.Tensor.reshape().
- Parameters:
in_shape (torch.Size) – The input event shape.
out_shape (torch.Size) – The output event shape.
cache_size (int) – Size of cache. If zero, no caching is done. If one,the latest single value is cached. Only 0 and 1 are supported. (Default 0.)
- classtorch.distributions.transforms.SigmoidTransform(cache_size=0)[source]#
Transform via the mapping y = 1 / (1 + exp(-x)) and x = logit(y).
- classtorch.distributions.transforms.SoftplusTransform(cache_size=0)[source]#
Transform via the mapping Softplus(x) = log(1 + exp(x)). The implementation reverts to the linear function when x > 20.
- classtorch.distributions.transforms.TanhTransform(cache_size=0)[source]#
Transform via the mapping y = tanh(x).
It is equivalent to
ComposeTransform(
    [
        AffineTransform(0.0, 2.0),
        SigmoidTransform(),
        AffineTransform(-1.0, 2.0),
    ]
)
However this might not be numerically stable, thus it is recommended to use TanhTransform instead.
Note that one should use cache_size=1 when it comes to NaN/Inf values.
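A brief sketch (not part of the original reference) of a tanh-squashed Normal, a common use of this transform; cache_size=1 is used as recommended above when log_prob is evaluated on values produced by the same transform:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Normal(torch.zeros(2), torch.ones(2))
squashed = TransformedDistribution(base, [TanhTransform(cache_size=1)])
a = squashed.rsample()     # values in (-1, 1)
squashed.log_prob(a)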
- classtorch.distributions.transforms.SoftmaxTransform(cache_size=0)[source]#
Transform from unconstrained space to the simplex via y = exp(x) then normalizing.
This is not bijective and cannot be used for HMC. However this acts mostlycoordinate-wise (except for the final normalization), and thus isappropriate for coordinate-wise optimization algorithms.
- classtorch.distributions.transforms.StackTransform(tseq,dim=0,cache_size=0)[source]#
Transform functor that applies a sequence of transforms tseq component-wise to each submatrix at dim in a way compatible with torch.stack().
Example:
x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
t = StackTransform([ExpTransform(), identity_transform], dim=1)
y = t(x)
- classtorch.distributions.transforms.StickBreakingTransform(cache_size=0)[source]#
Transform from unconstrained space to the simplex of one additional dimension via a stick-breaking process.
This transform arises as an iterated sigmoid transform in a stick-breaking construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.
This is bijective and appropriate for use in HMC; however it mixescoordinates together and is less appropriate for optimization.
- classtorch.distributions.transforms.Transform(cache_size=0)[source]#
Abstract class for invertible transformations with computable log det jacobians. They are primarily used in torch.distributions.TransformedDistribution.
Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values since the autograd graph may be reversed. For example while the following works with or without caching:
y = t(x)
t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.
However the following will error when caching due to dependency reversal:
y = t(x)
z = t.inv(y)
grad(z.sum(), [y])  # error because z is x
Derived classes should implement one or both of
_call() or _inverse(). Derived classes that set bijective=True should also implement log_abs_det_jacobian().
- Parameters:
cache_size (int) – Size of cache. If zero, no caching is done. If one,the latest single value is cached. Only 0 and 1 are supported.
- Variables:
domain (Constraint) – The constraint representing valid inputs to this transform.
codomain (Constraint) – The constraint representing valid outputs to this transform which are inputs to the inverse transform.
bijective (bool) – Whether this transform is bijective. A transform t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least maintain the weaker pseudoinverse properties t(t.inv(t(x))) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
sign (int or Tensor) – For bijective univariate transforms, this should be +1 or -1 depending on whether transform is monotone increasing or decreasing.
- propertyinv:Transform#
Returns the inverse Transform of this transform. This should satisfy t.inv.inv is t.
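A brief sketch (not part of the original reference) of the contract stated above, using ExpTransform as an arbitrary example:

import torch
from torch.distributions.transforms import ExpTransform

t = ExpTransform()
x = torch.randn(3)
torch.allclose(t.inv(t(x)), x)   # the inverse undoes the transform
t.inv.inv is t                   # True, per the contract above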
- propertysign:int#
Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.
Constraints#
- classtorch.distributions.constraints.Constraint[source]#
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
- Variables:
- torch.distributions.constraints.is_dependent(constraint)[source]#
Checks if constraint is a _Dependent object.
- Parameters:
constraint – A Constraint object.
- Returns:
True if constraint can be refined to the type _Dependent, False otherwise.
- Return type:
bool
Examples
>>> import torch
>>> from torch.distributions import Bernoulli
>>> from torch.distributions.constraints import is_dependent

>>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))
>>> constraint1 = dist.arg_constraints["probs"]
>>> constraint2 = dist.arg_constraints["logits"]

>>> for constraint in [constraint1, constraint2]:
>>>     if is_dependent(constraint):
>>>         continue
- classtorch.distributions.constraints.MixtureSameFamilyConstraint(base_constraint)[source]#
Constraint for the MixtureSameFamily distribution that adds back the rightmost batch dimension before performing the validity check with the component distribution constraint.
- Parameters:
base_constraint – The Constraint object of the component distribution of the MixtureSameFamily distribution.
ConstraintRegistry#
PyTorch provides two global ConstraintRegistry objects that link Constraint objects to Transform objects. These objects both input constraints and return transforms, but they have different guarantees on bijectivity.
biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().
transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().
The transform_to() registry is useful for performing unconstrained optimization on constrained parameters of probability distributions, which are indicated by each distribution's .arg_constraints dict. These transforms often overparameterize a space in order to avoid rotation; they are thus more suitable for coordinate-wise optimization algorithms like Adam:
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()
The biject_to() registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained .support are propagated in an unconstrained space, and algorithms are typically rotation invariant:
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()
Note
An example where transform_to and biject_to differ is constraints.simplex: transform_to(constraints.simplex) returns a SoftmaxTransform that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, biject_to(constraints.simplex) returns a StickBreakingTransform that bijects its input down to a one-fewer-dimensional space; this is a more expensive, less numerically stable transform but is needed for algorithms like HMC.
The biject_to and transform_to objects can be extended by user-defined constraints and transforms using their .register() method either as a function on singleton constraints:
transform_to.register(my_constraint, my_transform)
or as a decorator on parameterized constraints:
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.param1, constraint.param2)
You can create your own registry by creating a new ConstraintRegistry object.
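A brief sketch (not part of the original reference) of a user-defined registry; the mapping chosen here (the positive constraint to an ExpTransform) is only an illustrative assumption:

from torch.distributions import constraints, transforms
from torch.distributions.constraint_registry import ConstraintRegistry

my_registry = ConstraintRegistry()

@my_registry.register(constraints.positive)
def _positive_to_exp(constraint):
    return transforms.ExpTransform()

t = my_registry(constraints.positive)   # looks up the registered transform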
- classtorch.distributions.constraint_registry.ConstraintRegistry[source]#
Registry to link constraints to transforms.
- register(constraint,factory=None)[source]#
Registers a Constraint subclass in this registry. Usage:

@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraint)
    return MyTransform(constraint.arg_constraints)
- Parameters:
constraint (subclass of Constraint) – A subclass of Constraint, or a singleton object of the desired class.
factory (Callable) – A callable that inputs a constraint object and returns a Transform object.