- Notifications
You must be signed in to change notification settings - Fork2.2k
Do not include rvs in symbolic normalizing constant#7787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Do not include rvs in symbolic normalizing constant#7787
Uh oh!
There was an error while loading.Please reload this page.
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Pull Request Overview
This PR ensures that random variables (RVs) are not included in the symbolic normalizing constant graph by folding shapes to constants, adds shape inference for minibatch RVs, includes a test for the new behavior, and fixes a small typo.
- Use
constant_foldto derive batch shapes instead of carrying RVs intosymbolic_normalizing_constant - Implement
infer_shapeonMinibatchRandomVariableso shape propagation works correctly - Add a dedicated test (
assert_no_rvs) to confirm no RVs appear in the symbolic normalizing constant - Correct a typo in the
constant_foldcomment
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/variational/test_opvi.py | Addedtest_symbolic_normalizing_constant_no_rvs withassert_no_rvs |
| pymc/variational/opvi.py | Swapped direct.shape usage forconstant_fold([...].shape) in scaling |
| pymc/variational/minibatch_rv.py | Addedinfer_shape method to propagate shapes without evaluation |
| pymc/pytensorf.py | Fixed typo in comment (constand_folding →constant_folding) |
Comments suppressed due to low confidence (3)
tests/variational/test_opvi.py:284
- [nitpick] The test verifies no RVs are in the graph but doesn't assert that the symbolic normalizing constant still produces the expected scalar or tensor shape. Consider adding an assertion on the returned value or shape to guard against regressions.
def test_symbolic_normalizing_constant_no_rvs():pymc/variational/opvi.py:1109
- Calling
constant_foldinside the list comprehension for each RV will repeatedly clone and rewrite the graph, which may be costly. Consider computing all shapes once (e.g., collect inputs, callconstant_foldoutside the loop) or caching results before the comprehension.
get_scaling(pymc/variational/opvi.py:1279
- This mirrored use of
constant_foldin another list comprehension also risks redundant graph rewriting. Extract a helper or hoist the folding step to improve efficiency and reduce duplicated logic.
get_scaling(Uh oh!
There was an error while loading.Please reload this page.
codecovbot commentedMay 17, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@## main #7787 +/- ##======================================= Coverage 92.84% 92.84% ======================================= Files 107 107 Lines 18378 18380 +2 =======================================+ Hits 17063 17065 +2 Misses 1315 1315
🚀 New features to boost your workflow:
|
Uh oh!
There was an error while loading.Please reload this page.
2551adf to0867cdeCompare618634b intopymc-devs:mainUh oh!
There was an error while loading.Please reload this page.
Uh oh!
There was an error while loading.Please reload this page.
This is a more proper fix for the problem highlighted in#7778
The normalizing constant for MinibatchRVs included the graph of the shape of the RVs.
Even though the shape of the MinibatchRV can be derived without evaluating the draws, passing any graph with RVs to
pytensorf.compilewill automatically integrate the updates which requires evaluating the RV anyway. This PR makes sure we don't include the RVs only to get the symbolic normalizing constant.📚 Documentation preview 📚:https://pymc--7787.org.readthedocs.build/en/7787/