Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0867cde

Browse files
committed
Do not include RVs in graph of symbolic_normalizing_constant
1 parentb3d575f commit0867cde

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

‎pymc/pytensorf.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def constant_fold(
979979
"""
980980
fg=FunctionGraph(outputs=xs,features=[ShapeFeature()],copy_inputs=False,clone=True)
981981

982-
# The default rewrite_graph includes aconstand_folding that is not always applied.
982+
# The default rewrite_graph includes aconstant_folding that is not always applied.
983983
# We use an unconditional constant_folding as the last pass to ensure a thorough constant folding.
984984
rewrite_graph(fg)
985985
topo_unconditional_constant_folding.apply(fg)

‎pymc/variational/opvi.py‎

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
frompymc.pytensorfimport (
7575
SeedSequenceSeed,
7676
compile,
77+
constant_fold,
7778
find_rng_nodes,
7879
reseed_rngs,
7980
)
@@ -1105,7 +1106,10 @@ def symbolic_normalizing_constant(self):
11051106
t=self.to_flat_input(
11061107
pt.max(
11071108
[
1108-
get_scaling(v.owner.inputs[1:],v.shape)
1109+
get_scaling(
1110+
v.owner.inputs[1:],
1111+
constant_fold([v.owner.inputs[0].shape],raise_not_constant=False),
1112+
)
11091113
forvinself.group
11101114
ifisinstance(v.owner.op,MinibatchRandomVariable)
11111115
]
@@ -1272,7 +1276,10 @@ def symbolic_normalizing_constant(self):
12721276
t=pt.max(
12731277
self.collect("symbolic_normalizing_constant")
12741278
+ [
1275-
get_scaling(obs.owner.inputs[1:],obs.shape)
1279+
get_scaling(
1280+
obs.owner.inputs[1:],
1281+
constant_fold([obs.owner.inputs[0].shape],raise_not_constant=False),
1282+
)
12761283
forobsinself.model.observed_RVs
12771284
ifisinstance(obs.owner.op,MinibatchRandomVariable)
12781285
]

‎tests/variational/test_opvi.py‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
importpymcaspm
2222

23+
frompymc.testingimportassert_no_rvs
2324
frompymc.variationalimportopvi
2425
frompymc.variational.approximationsimport (
2526
Empirical,
@@ -278,3 +279,18 @@ def test_logq_globals(three_var_approx):
278279
es=symbolic_logq.eval()
279280
asserte.shape== ()
280281
assertes.shape== (2,)
282+
283+
284+
deftest_symbolic_normalizing_constant_no_rvs():
285+
# Test that RVs aren't included in the graph of symbolic_normalizing_constant
286+
rng=np.random.default_rng()
287+
288+
withpm.Model()asm:
289+
obs=pm.Data("obs",rng.normal(size=(1000,)))
290+
obs_batch=pm.Minibatch(obs,batch_size=128)
291+
x=pm.Normal("x")# Need at least one Free_RV in the graph
292+
y_hat=pm.Flat("y_hat",observed=obs_batch,total_size=1000)
293+
294+
step=pm.ADVI()
295+
296+
assert_no_rvs(step.approx.symbolic_normalizing_constant)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp