Conditional Molecular Structure Generation
Ferg-Lab/molgen
This package implements Generative Adversarial Networks (GANs) and Denoising Diffusion Probabilistic Models (DDPMs) for generative tasks such as conditional molecular structure generation.
To use `molgen`, you will need an environment with the following packages:
- Python 3.7+
- PyTorch
- PyTorch Lightning
- Einops
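Before installing, it can help to confirm that the environment actually provides the packages listed above. The following is a minimal sketch using only the standard library; the import names `torch`, `pytorch_lightning`, and `einops` are assumptions about how the listed packages are exposed, and the helper names are hypothetical.

```python
import importlib.util
import sys

# Assumed import names for PyTorch, PyTorch Lightning, and Einops.
REQUIRED_MODULES = ["torch", "pytorch_lightning", "einops"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_is_supported(minimum=(3, 7)):
    """Check the running interpreter against the minimum Python version."""
    return sys.version_info >= minimum


if __name__ == "__main__":
    if not python_is_supported():
        print("Python 3.7 or newer is required")
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found")
```

The check only looks up whether each module can be located; it does not import the packages or verify their versions.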
For running and visualizing examples:
Once you have these packages installed, you can install `molgen` in the same environment using
$ pip install -e .
Once installed, you can use the package. This example trains a WGANGP to reproduce the alanine dipeptide backbone atoms conditioned on the backbone dihedral angles (see the `examples` directory).
```python
from molgen.models import WGANGP
from pathlib import Path
import mdtraj as md
import torch
import numpy as np

# load data
pdb_fname = 'examples/data/alanine-dipeptide-nowater.pdb'
trj_fnames = [str(i) for i in Path('examples/data/').glob('alanine-dipeptide-*-250ns-nowater.xtc')]
trjs = [md.load(t, top=pdb_fname).center_coordinates() for t in trj_fnames]

# process xyz coordinates and conditioning variables
xyz = list()
phi_psi = list()
for trj in trjs:
    t_backbone = trj.atom_slice(trj.top.select('backbone')).center_coordinates()
    n = trj.xyz.shape[0]
    _, phi = md.compute_phi(trj)
    _, psi = md.compute_psi(trj)
    xyz.append(torch.tensor(t_backbone.xyz.reshape(n, -1)).float())
    phi_psi.append(torch.tensor(np.concatenate((phi, psi), -1)).float())

# instantiate the model
model = WGANGP(xyz[0].shape[1], phi_psi[0].shape[1])

# fit the model
model.fit(xyz, phi_psi, max_epochs=25)

# Generate synthetic configurations
xyz_gen = model.generate(torch.cat(phi_psi))
xyz_gen = xyz_gen.reshape(xyz_gen.size(0), -1, 3)

# Save model checkpoint
model.save('ADP.ckpt')

# Load from checkpoint
model = WGANGP.load_from_checkpoint('ADP.ckpt')
```
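The example flattens each frame's coordinates into a single feature vector before training and reshapes generated samples back into per-atom coordinates afterwards. The sketch below illustrates only that shape bookkeeping with NumPy on random data; the frame and atom counts are hypothetical, and it assumes one phi and one psi angle per frame.

```python
import numpy as np

# Hypothetical sizes chosen for illustration only.
n_frames, n_atoms = 4, 5

rng = np.random.default_rng(0)
coords = rng.normal(size=(n_frames, n_atoms, 3)).astype(np.float32)
phi = rng.uniform(-np.pi, np.pi, size=(n_frames, 1)).astype(np.float32)
psi = rng.uniform(-np.pi, np.pi, size=(n_frames, 1)).astype(np.float32)

# Flatten (n_frames, n_atoms, 3) -> (n_frames, n_atoms * 3), as done for xyz.
flat = coords.reshape(n_frames, -1)

# Stack the two angles into one conditioning vector per frame.
cond = np.concatenate((phi, psi), -1)

# Undo the flattening, as done for the generated configurations.
restored = flat.reshape(flat.shape[0], -1, 3)

print(flat.shape, cond.shape, restored.shape)
```

The two sizes passed to the model constructor in the example correspond to the second dimensions of `flat` and `cond` here.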
`molgen` supports generators based on both Generative Adversarial Networks (GANs) and Denoising Diffusion Probabilistic Models (DDPMs). The example above uses a GAN. DDPMs support an equivalent API, for example:
```python
from molgen.models import DDPM
model = DDPM(....)
```
Code for the DDPM models is taken from https://github.com/lucidrains/denoising-diffusion-pytorch (version 1.0.5).
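For readers unfamiliar with DDPMs, the sketch below shows the generic forward (noising) step that this family of models is trained to reverse. It is not taken from `molgen` or from the repository linked above; the linear beta schedule, its endpoints, the number of timesteps, and the function names are all assumptions chosen for illustration.

```python
import numpy as np


def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Noise variances per step (assumed linear schedule)."""
    return np.linspace(beta_start, beta_end, timesteps)


def q_sample(x0, t, alphas_cumprod, noise):
    """Sample x_t from q(x_t | x_0) = N(sqrt(a_bar_t) * x0, (1 - a_bar_t) * I)."""
    signal_scale = np.sqrt(alphas_cumprod[t])
    noise_scale = np.sqrt(1.0 - alphas_cumprod[t])
    return signal_scale * x0 + noise_scale * noise


timesteps = 1000
betas = linear_beta_schedule(timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)  # a_bar_t, decreasing towards 0

rng = np.random.default_rng(0)
x0 = rng.normal(size=(2, 15))      # stand-in for flattened coordinates
noise = rng.normal(size=x0.shape)

x_early = q_sample(x0, 0, alphas_cumprod, noise)             # almost the data
x_late = q_sample(x0, timesteps - 1, alphas_cumprod, noise)  # almost pure noise
```

A trained model learns to undo this corruption step by step, and in the conditional setting it also receives the conditioning variables (here, the dihedral angles) as input.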
Copyright (c) 2023, Kirill Shmilovich
Project based on the Computational Molecular Science Python Cookiecutter version 1.1.