lucidrains/egnn-pytorch

**A bug has been discovered with the neighbor selection in the presence of masking. If you ran any experiments prior to 0.1.12 that had masking, please rerun them. 🙏**

## EGNN - Pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May eventually be used for Alphafold2 replication. This technique went for simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc.
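The "simple invariant features" idea can be sanity-checked with a minimal, self-contained sketch in plain PyTorch. Note this is a simplified stand-in for illustration, not this library's implementation: messages are built only from invariant pairwise squared distances, and the coordinate update moves each node along relative position vectors, so rotations, reflections, and translations commute with the layer.

```python
import torch

torch.manual_seed(0)

def toy_egnn_layer(feats, coors, w_msg, w_coor, w_feat):
    """One simplified EGNN-style message passing step (illustrative only)."""
    n = feats.shape[1]
    rel = coors[:, :, None, :] - coors[:, None, :, :]          # (b, n, n, 3) relative positions
    dist = (rel ** 2).sum(dim = -1, keepdim = True)            # invariant pairwise squared distances
    h_i = feats[:, :, None, :].expand(-1, -1, n, -1)
    h_j = feats[:, None, :, :].expand(-1, n, -1, -1)
    m = torch.tanh(torch.cat([h_i, h_j, dist], dim = -1) @ w_msg)  # messages from invariants only
    coors = coors + (rel * (m @ w_coor)).sum(dim = 2)          # equivariant coordinate update
    feats = feats + m.sum(dim = 2) @ w_feat                    # invariant feature update
    return feats, coors

d = 8
w_msg  = torch.randn(2 * d + 1, d) * 0.1
w_coor = torch.randn(d, 1) * 0.1
w_feat = torch.randn(d, d) * 0.1

feats = torch.randn(1, 5, d)
coors = torch.randn(1, 5, 3)

R, _ = torch.linalg.qr(torch.randn(3, 3))   # random orthogonal matrix (rotation or reflection)
t = torch.randn(3)                          # random translation

f1, c1 = toy_egnn_layer(feats, coors, w_msg, w_coor, w_feat)
f2, c2 = toy_egnn_layer(feats, coors @ R + t, w_msg, w_coor, w_feat)

assert torch.allclose(f1, f2, atol = 1e-5)          # features are E(3)-invariant
assert torch.allclose(c1 @ R + t, c2, atol = 1e-5)  # coordinates are E(3)-equivariant
```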

## Install

```bash
$ pip install egnn-pytorch
```

## Usage

```python
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
```

With edges

```python
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
```

A full EGNN network

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,           # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
```

Only attend to sparse neighbors, given to the network as an adjacency matrix.

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as an edge, with two extra keyword arguments.

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,            # fetch up to 3rd degree neighbors
    adj_dim = 8,                    # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

## Edges

If you need to pass in continuous edges

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```

## Stability

The initial architecture for EGNN suffered from instability when there was a high number of neighbors. Thankfully, there seem to be two solutions that largely mitigate this.

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,              # normalize the relative coordinates
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
```

## All parameters

```python
import torch
from egnn_pytorch import EGNN

model = EGNN(
    dim = dim,                       # input dimension
    edge_dim = 0,                    # dimension of the edges, if exists, should be > 0
    m_dim = 16,                      # hidden model dimension
    fourier_features = 0,            # number of fourier features for encoding of relative distance - defaults to none as in paper
    num_nearest_neighbors = 0,       # cap the number of neighbors doing message passing by relative distance
    dropout = 0.0,                   # dropout
    norm_feats = False,              # whether to layernorm the features
    norm_coors = False,              # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
    update_feats = True,             # whether to update features - you can build a layer that only updates one or the other
    update_coors = True,             # whether to update coordinates
    only_sparse_neighbors = False,   # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
    valid_radius = float('inf'),     # the valid radius each node considers for message passing
    m_pool_method = 'sum',           # whether to mean or sum pool for output node representation
    soft_edges = False,              # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper
    coor_weights_clamp_value = None  # clamping of the coordinate updates, again, for stabilization purposes
)
```

## Examples

To run the protein backbone denoising example, first install sidechainnet

```bash
$ pip install sidechainnet
```

Then

```bash
$ python denoise_sparse.py
```

## Tests

Make sure you have PyTorch Geometric installed locally

```bash
$ python setup.py test
```

## Citations

```bibtex
@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```
