Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitac73b55

Browse files
committed
Allow truncation of hurdle distributions
1 parent3f42edd commitac73b55

File tree

3 files changed

+70
-8
lines changed

3 files changed

+70
-8
lines changed

‎pymc/distributions/mixture.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
)
4040
frompymc.distributions.shape_utilsimport_change_dist_size,change_dist_size,rv_size_is_none
4141
frompymc.distributions.transformsimport_default_transform
42-
frompymc.distributions.truncatedimportTruncated
4342
frompymc.logprob.abstractimport_logcdf,_logcdf_helper,_logprob
4443
frompymc.logprob.basicimportlogp
4544
frompymc.logprob.transformsimportIntervalTransform
@@ -831,6 +830,8 @@ def _create(cls, *, name, nonzero_p, nonzero_dist, max_n_steps=10_000, **kwargs)
831830
832831
Note: this is invalid for discrete nonzero distributions with mass below 0, as we simply truncate[lower=1].
833832
"""
833+
frompymc.distributions.truncatedimportTruncated
834+
834835
dtype=nonzero_dist.dtype
835836

836837
ifdtype.startswith("int"):

‎pymc/distributions/truncated.py‎

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_support_point,
3636
support_point,
3737
)
38+
frompymc.distributions.mixtureimport_HurdleRV
3839
frompymc.distributions.shape_utilsimport (
3940
_change_dist_size,
4041
change_dist_size,
@@ -79,7 +80,9 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
7980

8081
# Try to use specialized Op
8182
try:
82-
return_truncated(dist.owner.op,lower,upper,size,*dist.owner.inputs)
83+
return_truncated(
84+
dist.owner.op,lower,upper,size,*dist.owner.inputs,max_n_steps=max_n_steps
85+
)
8386
exceptNotImplementedError:
8487
pass
8588

@@ -222,7 +225,7 @@ def update(self, node: Apply):
222225

223226

224227
@singledispatch
225-
def_truncated(op:Op,lower,upper,size,*params):
228+
def_truncated(op:Op,lower,upper,size,*params,max_n_steps:int):
226229
"""Return the truncated equivalent of another `RandomVariable`."""
227230
raiseNotImplementedError(f"{op} does not have an equivalent truncated version implemented")
228231

@@ -307,13 +310,14 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)
307310
f"Truncation dist must be a distribution created via the `.dist()` API, got{type(dist)}"
308311
)
309312

310-
if (
311-
isinstance(dist.owner.op,SymbolicRandomVariable)
312-
and"[size]"notindist.owner.op.extended_signature
313+
ifisinstance(dist.owner.op,SymbolicRandomVariable)andnot (
314+
"[size]"indist.owner.op.extended_signature
315+
# If there's a specific _truncated dispatch for this RV, that's also fine
316+
or_truncated.dispatch(type(dist.owner.op))isnot_truncated.dispatch(object)
313317
):
314318
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
315319
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
316-
# distribution factories like `Censored`and `Mixture`which would have a very complex signature if they
320+
# distribution factories like `Censored` which would have a very complex signature if they
317321
# encapsulated the random components instead of taking them as inputs like they do now.
318322
# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
319323
raiseNotImplementedError(f"Truncation not implemented for{dist.owner.op}")
@@ -462,7 +466,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
462466

463467

464468
@_truncated.register(NormalRV)
465-
def_truncated_normal(op,lower,upper,size,rng,old_size,mu,sigma):
469+
def_truncated_normal(op,lower,upper,size,rng,old_size,mu,sigma,*,max_n_steps):
466470
returnTruncatedNormal.dist(
467471
mu=mu,
468472
sigma=sigma,
@@ -472,3 +476,32 @@ def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
472476
size=size,
473477
dtype=op.dtype,
474478
)
479+
480+
481+
@_truncated.register(_HurdleRV)
482+
def_truncated_hurdle(
483+
op:_HurdleRV,lower,upper,size,rng,weights,zero_dist,dist,max_n_steps
484+
):
485+
# If the DiracDelta value is outside the truncation bounds, this is effectively a non-hurdle distribution
486+
# We achieve this by adjusting the weights of the DiracDelta component, so it's never selected in that case
487+
psi=weights[...,1]
488+
489+
checks=np.array(True)
490+
iflowerisnotNone:
491+
checks&=lower<=0
492+
ifupperisnotNone:
493+
checks&=0<=upper
494+
495+
adjusted_psi=pt.where(
496+
checks,
497+
psi,
498+
1,
499+
)
500+
adjusted_weights=pt.stack([1-adjusted_psi,adjusted_psi],axis=-1)
501+
502+
# The only remaining step is to truncate the other distribution
503+
truncated_dist=Truncated.dist(dist,lower=lower,upper=upper,max_n_steps=max_n_steps)
504+
505+
# Creating a hurdle with the adjusted weights and the truncated distribution
506+
# Should be equivalent to truncating the original hurdle distribution
507+
returnop.rv_op(adjusted_weights,zero_dist,truncated_dist,size=size)

‎tests/distributions/test_mixture.py‎

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Poisson,
5050
StickBreakingWeights,
5151
Triangular,
52+
Truncated,
5253
Uniform,
5354
ZeroInflatedBinomial,
5455
ZeroInflatedNegativeBinomial,
@@ -1710,3 +1711,30 @@ def logp_fn(value, psi, mu, sigma):
17101711
returnnp.log(psi)+st.lognorm.logpdf(value,sigma,0,np.exp(mu))
17111712

17121713
check_logp(HurdleLogNormal,Rplus, {"psi":Unit,"mu":R,"sigma":Rplusbig},logp_fn)
1714+
1715+
@pytest.mark.parametrize("lower", (-np.inf,0,None,1))
1716+
deftest_truncated_hurdle_lognormal(self,lower):
1717+
psi=0.7
1718+
x=HurdleLogNormal.dist(psi=psi,mu=3,sigma=1)
1719+
x_trunc=Truncated.dist(x,lower=lower,upper=30,size=(1000,))
1720+
1721+
x_trunc_draws=draw(x_trunc)
1722+
assert ((x_trunc_draws>= (loweror-np.inf))& (x_trunc_draws<=30)).all()
1723+
1724+
x_trunc=Truncated.dist(x,lower=lower,upper=30,size=(4,))
1725+
x_trunc_logp=logp(x_trunc, [0,5.5,30.0,30.1]).eval()
1726+
effective_psi=psiif (loweror-np.inf)<=0else1
1727+
np.testing.assert_allclose(
1728+
x_trunc_logp,
1729+
[
1730+
np.log(1-effective_psi),# 0 is not in the support of the distribution
1731+
*(
1732+
np.log(effective_psi)
1733+
+logp(
1734+
Truncated.dist(LogNormal.dist(mu=3,sigma=1),lower=lower,upper=30),
1735+
[5.5,30.0],
1736+
)
1737+
).eval(),
1738+
-np.inf,# 30.1 is outside the upper bound
1739+
],
1740+
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp