feat: batched sampling for MCMC #1176

Merged

janfb merged 81 commits into main from amortized_sample_mcmc on Jul 30, 2024

Conversation

@manuelgloeckler (Contributor) commented Jun 18, 2024, edited by gmoss13

What does this implement/fix? Explain your changes

This pull request aims to implement the sample_batched method for MCMC.

Current problem

  • BasePotential can either "allow_iid" or not. Hence, each batch dimension will be interpreted as IID samples.
    • Replace allow_iid with a mutable attribute (or optional input argument) interpret_as_iid.
    • Remove the warning for batched x and default to batched evaluation.
  • Refactor all MCMC initialization methods to work with a batch dim.
    • resample should break
    • SIR should break
    • proposal should work
  • Add tests to check whether the correct samples are in each dimension (currently, only shapes are checked).
    • The problem is currently not caught by tests...

The current implementation will let you sample the correct shape, BUT will output the wrong solution. This is because the potential function will broadcast, repeat, and finally sum over the first dimension, which is incorrect.
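The failure mode described above can be reproduced with a toy potential (a minimal numpy sketch for illustration, not sbi's actual potential class; the quadratic log-density log p(x | theta) = -0.5 (x - theta)^2 is an assumption):

```python
import numpy as np

# Toy potential for illustration: log p(x | theta) = -0.5 * (x - theta)^2.
x = np.array([0.0, 10.0])      # two *different* observations
theta = np.array([0.0, 10.0])  # one MCMC chain per observation

# IID interpretation (current behaviour): broadcast, repeat, and sum over
# the x batch dimension -- every theta is scored against BOTH observations.
iid_potential = (-0.5 * (x[:, None] - theta[None, :]) ** 2).sum(axis=0)

# Batched interpretation (what sample_batched needs): the i-th chain is
# scored only against the i-th observation.
batched_potential = -0.5 * (x - theta) ** 2

print(iid_potential)      # each entry is -50: the two observations got mixed
print(batched_potential)  # each entry is 0: every chain sees its own x
```

With the IID reduction, both chains receive the same (wrong) potential, which is why the samples come out with the right shape but the wrong values.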

manuelgloeckler and others added 30 commits, April 29, 2024 09:04
…posteriors' into amortizedsample" This reverts commit 07084e2, reversing changes made to f16622d.
…from-different-posteriors' into amortizedsample
@gmoss13 (Contributor)
I've made some progress now towards this PR, and would like some feedback before I continue.

BasePotential can either "allow_iid" or not.

Given batch_dim_theta != batch_dim_x, we need to decide how to evaluate potential(x, theta). We could return (batch_dim_x, batch_dim_theta) potentials (i.e., every combination), but I am worried this can add a lot of computational overhead, especially when sampling. Instead, in the current implementation I suggest that we assume batch_dim_theta is a multiple of batch_dim_x (i.e., for sampling, we have n chains in theta for each x). In this case we expand the batch dim of x to batch_theta, and match which x goes to which theta. If we are happy with this approach, I'll go ahead and apply this also to the MCMC init_strategy, etc., and make sure this is consistent with other calls.

Remove warning for batched x and default to batched evaluation

Not sure if we want batched evaluation as the default. I think it's easier to do batched evaluation when sample_batched or log_prob_batched is called, and otherwise assume iid (and warn if batch dim > 1 as before).
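The expand-and-match scheme proposed above could be sketched as follows (a hypothetical numpy helper, not the actual sbi implementation; the contiguous layout, where the chains for one x are stored next to each other, is an assumption):

```python
import numpy as np

def match_x_to_theta(x: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """Expand the batch of x to align with the batch of theta.

    Hypothetical helper for illustration. Assumes batch_theta is an
    integer multiple of batch_x and that the n chains belonging to one
    x are stored contiguously.
    """
    batch_x, batch_theta = x.shape[0], theta.shape[0]
    if batch_theta % batch_x != 0:
        raise ValueError("batch_theta must be a multiple of batch_x")
    n_chains = batch_theta // batch_x
    # Repeat each x once per chain: x_0, x_0, ..., x_1, x_1, ...
    return np.repeat(x, n_chains, axis=0)

x = np.arange(2, dtype=float).reshape(2, 1)  # batch_x = 2
theta = np.zeros((6, 3))                     # batch_theta = 6 -> 3 chains per x
print(match_x_to_theta(x, theta).shape)      # (6, 1)
```

After this expansion, the potential can be evaluated elementwise along the leading dimension, so each chain is scored only against its own observation.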

@gmoss13 requested a review from janfb, June 27, 2024 16:04
@manuelgloeckler (Contributor, Author)

Great, it looks good. I like that the choice of iid or not can now be made at the set_x method, which makes a lot of sense.

I would also opt for your suggested option. The question arises because we squeeze the batch_shape into a single dimension, right? For "PyTorch" broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_x_dim, batch_theta_dim), so by squeezing the xs and thetas into 2d one would always get a dimension that is a multiple of batch_x_dim (otherwise it cannot be represented by a fixed-size tensor).

For (1, batch_x_dim, x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between cases can no longer be fully preserved.
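The all-combinations case mentioned above can be illustrated in numpy, which follows the same broadcasting rules as PyTorch (the quadratic toy potential is again an assumption for illustration):

```python
import numpy as np

# With a leading singleton dimension on each side, broadcasting evaluates
# a toy potential for EVERY (theta_j, x_i) combination:
batch_x, batch_theta, dim = 2, 3, 1
x = np.random.rand(1, batch_x, dim)          # (1, batch_x, dim)
theta = np.random.rand(batch_theta, 1, dim)  # (batch_theta, 1, dim)

potentials = (-0.5 * (x - theta) ** 2).sum(axis=-1)
print(potentials.shape)  # (3, 2): all combinations of theta and x

# Once both batches are squeezed into a single leading dimension, this
# pairing information is lost, so only layouts where batch_theta is a
# multiple of batch_x remain representable.
```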

@janfb (Contributor) left a comment

Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

@gmoss13 (Contributor)

Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

Thanks for the review! I implemented your suggestions.

An additional point - For posterior_based_potential, indeed we should not allow for iid_x, as this is handled by PermutationInvariantNetwork. Instead, we now always treat x batches as not iid. If the user tries to set potential.set_x(x, x_is_iid=True) with a PosteriorBasedPotential, we raise an error stating this. I added a few test cases in embedding_net_test.py::test_embedding_api_with_multiple_trials to test whether batches of x are interpreted correctly when we use a PermutationInvariantNetwork.
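The guard described above could look roughly like this (a minimal mock for illustration; the class name and error message are hypothetical, not sbi's actual PosteriorBasedPotential):

```python
# Mock of the behaviour described above: a posterior-based potential
# rejects x_is_iid=True, since iid trials are handled by the embedding net.
class PosteriorBasedPotentialSketch:
    def set_x(self, x, x_is_iid=False):
        if x_is_iid:
            raise ValueError(
                "iid interpretation of x is not allowed for posterior-based "
                "potentials; trials are handled by the embedding net."
            )
        self.x = x  # batch dimension is treated as independent (non-iid) xs

potential = PosteriorBasedPotentialSketch()
potential.set_x([1.0, 2.0])  # fine: batched x, interpreted as not iid
try:
    potential.set_x([1.0, 2.0], x_is_iid=True)
except ValueError as err:
    print("raised:", err)
```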

@gmoss13 requested a review from janfb, July 19, 2024 15:51
@janfb (Contributor) left a comment

Looks great! I added just a couple of last questions.

@gmoss13 requested a review from janfb, July 30, 2024 08:11
@janfb (Contributor) left a comment

Looks good! Thanks a lot, great effort!

@janfb self-assigned this Jul 30, 2024
@janfb added the enhancement (New feature or request) label Jul 30, 2024
@janfb (Contributor)

closes #990
closes #944

@janfb merged commit 81fffcf into main, Jul 30, 2024
5 of 6 checks passed
@janfb deleted the amortized_sample_mcmc branch, July 30, 2024 09:24
Reviewers

@gmoss13 left review comments

@janfb approved these changes

Assignees

@janfb

Labels
enhancement (New feature or request)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add sample_batched and log_prob_batched to posteriors
Allow sampling the posterior given different x (batched)
4 participants
@manuelgloeckler @gmoss13 @janfb @michaeldeistler
