3535_support_point ,
3636support_point ,
3737)
38+ from pymc .distributions .mixture import _HurdleRV
3839from pymc .distributions .shape_utils import (
3940_change_dist_size ,
4041change_dist_size ,
@@ -79,7 +80,9 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
7980
8081# Try to use specialized Op
8182try :
82- return _truncated (dist .owner .op ,lower ,upper ,size ,* dist .owner .inputs )
83+ return _truncated (
84+ dist .owner .op ,lower ,upper ,size ,* dist .owner .inputs ,max_n_steps = max_n_steps
85+ )
8386except NotImplementedError :
8487pass
8588
@@ -222,7 +225,7 @@ def update(self, node: Apply):
222225
223226
224227@singledispatch
225- def _truncated (op :Op ,lower ,upper ,size ,* params ):
228+ def _truncated (op :Op ,lower ,upper ,size ,* params , max_n_steps : int ):
226229"""Return the truncated equivalent of another `RandomVariable`."""
227230raise NotImplementedError (f"{ op } does not have an equivalent truncated version implemented" )
228231
@@ -307,13 +310,14 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)
307310f"Truncation dist must be a distribution created via the `.dist()` API, got{ type (dist )} "
308311 )
309312
310- if (
311- isinstance (dist .owner .op ,SymbolicRandomVariable )
312- and "[size]" not in dist .owner .op .extended_signature
313+ if isinstance (dist .owner .op ,SymbolicRandomVariable )and not (
314+ "[size]" in dist .owner .op .extended_signature
315+ # If there's a specific _truncated dispatch for this RV, that's also fine
316+ or _truncated .dispatch (type (dist .owner .op ))is not _truncated .dispatch (object )
313317 ):
314318# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
315319# random graph and as such we don't know where the actual inputs begin. This happens mostly for
316- # distribution factories like `Censored`and `Mixture` which would have a very complex signature if they
320+ # distribution factories like `Censored` which would have a very complex signature if they
317321# encapsulated the random components instead of taking them as inputs like they do now.
318322# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
319323raise NotImplementedError (f"Truncation not implemented for{ dist .owner .op } " )
@@ -462,7 +466,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
462466
463467
464468@_truncated .register (NormalRV )
465- def _truncated_normal (op ,lower ,upper ,size ,rng ,old_size ,mu ,sigma ):
469+ def _truncated_normal (op ,lower ,upper ,size ,rng ,old_size ,mu ,sigma , * , max_n_steps ):
466470return TruncatedNormal .dist (
467471mu = mu ,
468472sigma = sigma ,
@@ -472,3 +476,32 @@ def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
472476size = size ,
473477dtype = op .dtype ,
474478 )
479+
480+
481+ @_truncated .register (_HurdleRV )
482+ def _truncated_hurdle (
483+ op :_HurdleRV ,lower ,upper ,size ,rng ,weights ,zero_dist ,dist ,max_n_steps
484+ ):
485+ # If the DiracDelta value is outside the truncation bounds, this is effectively a non-hurdle distribution
486+ # We achieve this by adjusting the weights of the DiracDelta component, so it's never selected in that case
487+ psi = weights [...,1 ]
488+
489+ checks = np .array (True )
490+ if lower is not None :
491+ checks &= lower <= 0
492+ if upper is not None :
493+ checks &= 0 <= upper
494+
495+ adjusted_psi = pt .where (
496+ checks ,
497+ psi ,
498+ 1 ,
499+ )
500+ adjusted_weights = pt .stack ([1 - adjusted_psi ,adjusted_psi ],axis = - 1 )
501+
502+ # The only remaining step is to truncate the other distribution
503+ truncated_dist = Truncated .dist (dist ,lower = lower ,upper = upper ,max_n_steps = max_n_steps )
504+
505+ # Creating a hurdle with the adjusted weights and the truncated distribution
506+ # Should be equivalent to truncating the original hurdle distribution
507+ return op .rv_op (adjusted_weights ,zero_dist ,truncated_dist ,size = size )