Graph convolutions in Keras with TensorFlow, PyTorch or Jax.
General | Requirements | Installation | Documentation | Implementation details | Literature | Data | Datasets | Training | Issues | Citing | References
The package `kgcnn` contains several layer classes to build up graph convolution models in Keras with TensorFlow, PyTorch or Jax as backend. Some models from the literature are given as examples. A documentation is generated in docs. The focus of kgcnn is (batched) graph learning for molecules (`kgcnn.molecule`) and materials (`kgcnn.crystal`). If you want to get in contact, feel free to open a discussion.
Note that `kgcnn>=4.0.0` requires `keras>=3.0.0`. Previous versions of kgcnn focused on ragged tensors of TensorFlow; hyperparameters for those models should transfer to kgcnn 4.0 by adding `input_tensor_type: "ragged"` and checking the order and dtype of the inputs.
Standard Python package requirements are installed automatically. However, you must make sure to install GPU/TPU acceleration for the backend of your choice.
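Keras 3 selects its backend via the `KERAS_BACKEND` environment variable (or `~/.keras/keras.json`). A minimal sketch, assuming you want the Jax backend:

```python
import os

# One of "tensorflow", "torch" or "jax"; must be set before importing keras.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import kgcnn
```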
Clone the repository or the latest release and install in editable mode, or install the latest release via the Python Package Index:
```shell
pip install kgcnn
```
Auto-documentation is generated at https://kgcnn.readthedocs.io/en/latest/index.html.
A graph of `N` nodes and `M` edges is commonly represented by lists of node and edge attributes, `node_attr` and `edge_attr`, respectively, plus a list of index pairs `(i, j)` that each represent a directed edge in the graph: `edge_index`. The feature dimension of the attributes is denoted by `F`. Alternatively, an adjacency matrix `A_ij` of shape `(N, N)` can be ascribed, which has entries of one where there is an edge between two nodes and zero elsewhere. Consequently, the sum over `A_ij` gives the number of edges `M`.
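For illustration, here is a small sketch of these quantities for a directed triangle graph (plain numpy, not part of kgcnn itself):

```python
import numpy as np

# Triangle graph with N=3 nodes, M=6 directed edges and feature dimension F=16.
node_attr = np.random.rand(3, 16)                                        # (N, F)
edge_index = np.array([[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]])  # (M, 2)
edge_attr = np.random.rand(6, 8)                                         # (M, F_edge)

# Equivalent adjacency matrix with ones where edges exist.
A = np.zeros((3, 3), dtype=int)
A[edge_index[:, 0], edge_index[:, 1]] = 1
assert A.sum() == len(edge_index)  # The sum of A_ij equals M.
```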
For learning on batches of graphs, the following tensor representation can be chosen:
- `node_attr`: Node attributes of shape `(batch, N, F)` and dtype `float`
- `edge_attr`: Edge attributes of shape `(batch, M, F)` and dtype `float`
- `edge_index`: Indices of shape `(batch, M, 2)` and dtype `int`
- `graph_attr`: Graph attributes of shape `(batch, F)` and dtype `float`
Graphs are stacked along the batch dimension `batch`. Note that for flexibly sized graphs the tensors either have to be padded up to a maximum `N`/`M`, or ragged tensors with a ragged rank of one are used. Alternatively, the following disjoint representation can be chosen:
- `node_attr`: Node attributes of shape `([N], F)` and dtype `float`
- `edge_attr`: Edge attributes of shape `([M], F)` and dtype `float`
- `edge_index`: Indices of shape `(2, [M])` and dtype `int`
- `batch_ID`: Graph ID of shape `([N], )` and dtype `int`
Here, the lists essentially represent one graph that consists of the disjoint sub-graphs of the batch, a representation introduced by PyTorch Geometric (PyG). For pooling, the graph assignment is stored in `batch_ID`. Note that since Jax can not have dynamic shapes, we use a padded disjoint representation that assigns all padded nodes to a discarded graph with index zero.
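To make the disjoint representation concrete, here is a small illustrative sketch (plain numpy, not kgcnn code) of how two graphs with 2 and 3 nodes are merged:

```python
import numpy as np

# Two graphs with 2 and 3 nodes, respectively.
edge_index_1 = np.array([[0, 1], [1, 0]])
edge_index_2 = np.array([[0, 1], [1, 2], [2, 0]])

# Shift the node indices of the second graph by the node count of the first.
edge_index = np.concatenate([edge_index_1, edge_index_2 + 2], axis=0).T  # (2, [M]) with M=5
batch_ID = np.array([0, 0, 1, 1, 1])  # Graph assignment for each of the [N]=5 nodes.
```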
The Keras layers in `kgcnn.layers` can be used with the PyG-compatible tensor representation, or even by simply wrapping a PyG model with `TorchModuleWrapper`. Efficient model loading can be achieved in multiple ways (see `kgcnn.io`). For the most simple Keras-like behaviour, the model can be fed with batched padded or ragged tensors, which are converted to and from the disjoint representation. Here is an example of a minimal message-passing GNN:
```python
import keras as ks
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.aggr import AggregateLocalEdges

# Example for padded input.
ns = ks.layers.Input(shape=(None, 64), dtype="float32", name="node_attributes")
e_idx = ks.layers.Input(shape=(None, 2), dtype="int64", name="edge_indices")
total_n = ks.layers.Input(shape=(), dtype="int64", name="total_nodes")  # Or mask
total_e = ks.layers.Input(shape=(), dtype="int64", name="total_edges")  # Or mask

# Cast the batched padded tensors to the disjoint representation.
n, idx, batch_id, _, _, _, _, _ = CastBatchedIndicesToDisjoint(uses_mask=False)(
    [ns, e_idx, total_n, total_e])

# One message passing step: gather, transform, aggregate, update, pool.
n_in_out = GatherNodes()([n, idx])
node_messages = ks.layers.Dense(64, activation='relu')(n_in_out)
node_updates = AggregateLocalEdges()([n, node_messages, idx])
n_node_updates = ks.layers.Concatenate()([n, node_updates])
n_embedding = ks.layers.Dense(1)(n_node_updates)
g_embedding = PoolingNodes()([total_n, n_embedding, batch_id])

message_passing = ks.models.Model(inputs=[ns, e_idx, total_n, total_e], outputs=g_embedding)
```
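The model above can then, for example, be called on padded dummy data. This is a minimal sketch; shapes and values are purely illustrative:

```python
import numpy as np

# Batch of two graphs, padded to N=3 nodes and M=3 edges.
nodes = np.random.rand(2, 3, 64).astype("float32")
edges = np.array([[[0, 1], [1, 0], [0, 0]],
                  [[0, 1], [1, 2], [2, 0]]], dtype="int64")
num_nodes = np.array([2, 3], dtype="int64")  # The first graph uses only 2 nodes.
num_edges = np.array([2, 3], dtype="int64")  # Its third edge entry is padding.

out = message_passing.predict([nodes, edges, num_nodes, num_edges])
print(out.shape)  # One graph embedding per batch entry.
```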
The actual message-passing model can be structured further, e.g. by subclassing the message-passing base layer:
```python
import keras as ks
from kgcnn.layers.message import MessagePassingBase

class MyMessageNN(MessagePassingBase):

    def __init__(self, units, **kwargs):
        super(MyMessageNN, self).__init__(**kwargs)
        self.dense = ks.layers.Dense(units)
        self.add = ks.layers.Add()

    def message_function(self, inputs, **kwargs):
        n_in, n_out, edges = inputs
        return self.dense(n_out, **kwargs)

    def update_nodes(self, inputs, **kwargs):
        nodes, nodes_update = inputs
        return self.add([nodes, nodes_update], **kwargs)
```
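A hypothetical call in the disjoint representation could then look as follows, assuming the base layer is fed with `[nodes, edges, edge_index]`, matching the unpacking in `message_function` above (`n`, `edge_attr` and `idx` stand for disjoint node attributes, edge attributes and edge indices):

```python
# Hypothetical disjoint tensors, e.g. from CastBatchedIndicesToDisjoint as above.
node_updates = MyMessageNN(units=64)([n, edge_attr, idx])
```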
The following models, proposed in the literature, have a module in literature. Each module usually exposes a `make_model` function to create a `keras.models.Model`. The models can, but do not have to, be built completely from `kgcnn.layers` and can for example include original implementations (with proper licensing). A usage sketch follows the list below.
- AttentiveFP: Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism by Xiong et al. (2019)
- CGCNN: Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties by Xie et al. (2018)
- CMPNN: Communicative Representation Learning on Attributed Molecular Graphs by Song et al. (2020)
- DGIN: Improved Lipophilicity and Aqueous Solubility Prediction with Composite Graph Neural Networks by Wieder et al. (2021)
- DimeNetPP: Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules by Klicpera et al. (2020)
- DMPNN: Analyzing Learned Molecular Representations for Property Prediction by Yang et al. (2019)
- EGNN: E(n) Equivariant Graph Neural Networks by Satorras et al. (2021)
- GAT: Graph Attention Networks by Veličković et al. (2018)
- GATv2: How Attentive are Graph Attention Networks? by Brody et al. (2021)
- GCN: Semi-Supervised Classification with Graph Convolutional Networks by Kipf et al. (2016)
- GIN: How Powerful are Graph Neural Networks? by Xu et al. (2019)
- GNNExplainer: GNNExplainer: Generating Explanations for Graph Neural Networks by Ying et al. (2019)
- GNNFilm: GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation by Marc Brockschmidt (2020)
- GraphSAGE: Inductive Representation Learning on Large Graphs by Hamilton et al. (2017)
- HamNet: HamNet: Conformation-Guided Molecular Representation with Hamiltonian Neural Networks by Li et al. (2021)
- HDNNP2nd: Atom-centered symmetry functions for constructing high-dimensional neural network potentials by Jörg Behler (2011)
- INorp: Interaction Networks for Learning about Objects, Relations and Physics by Battaglia et al. (2016)
- MAT: Molecule Attention Transformer by Maziarka et al. (2020)
- MEGAN: MEGAN: Multi-explanation Graph Attention Network by Teufel et al. (2023)
- Megnet: Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals by Chen et al. (2019)
- MoGAT: Multi-order graph attention network for water solubility prediction and interpretation by Lee et al. (2023)
- MXMNet: Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures by Zhang et al. (2020)
- NMPN: Neural Message Passing for Quantum Chemistry by Gilmer et al. (2017)
- PAiNN: Equivariant message passing for the prediction of tensorial properties and molecular spectra by Schütt et al. (2020)
- RGCN: Modeling Relational Data with Graph Convolutional Networks by Schlichtkrull et al. (2017)
- rGIN: Random Features Strengthen Graph Neural Networks by Sato et al. (2020)
- Schnet: SchNet – A deep learning architecture for molecules and materials by Schütt et al. (2017)
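As referenced above, a typical usage sketch for a literature model could look like the following. The exact hyperparameter keys and default inputs differ per model, so treat the arguments below (e.g. `inputs`, `output_mlp` and their values) as illustrative assumptions rather than a fixed signature:

```python
from kgcnn.literature.GCN import make_model

# Illustrative hyperparameters; each module defines its own defaults.
model = make_model(
    inputs=[
        {"shape": (None, 41), "name": "node_attributes", "dtype": "float32"},
        {"shape": (None, 1), "name": "edge_weights", "dtype": "float32"},
        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
        {"shape": (), "name": "total_nodes", "dtype": "int64"},
        {"shape": (), "name": "total_edges", "dtype": "int64"},
    ],
    output_mlp={"units": [64, 1], "activation": ["relu", "linear"]},
)
```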
Data handling classes are given in `kgcnn.data`, which stores graphs as `List[Dict]`.
Graphs are represented by a dictionary `GraphDict` of (numpy) arrays, which behaves like a Python dict. There are graph pre- and postprocessors in `kgcnn.graph` which take specific properties by name and apply a processing function or transformation.
> **Important:** Processors can apply any operation, but note that `GraphDict` does not impose an actual graph structure! For example, to sort edge indices, make sure that all edge attributes are sorted accordingly.
```python
from kgcnn.graph import GraphDict

# Single graph.
graph = GraphDict({"edge_indices": [[1, 0], [0, 1]], "node_label": [[0], [1]]})
graph.set("graph_labels", [0])  # Use set(), get() to assign (tensor) properties.
graph.set("edge_attributes", [[1.0], [2.0]])
graph.to_networkx()

# Modify with e.g. a preprocessor. The regular expression matches all edge
# properties except the indices themselves, so attributes stay aligned.
from kgcnn.graph.preprocessor import SortEdgeIndices
SortEdgeIndices(edge_indices="edge_indices", edge_attributes="^edge_(?!indices$).*", in_place=True)(graph)
```
A `MemoryGraphList` should behave identically to a Python list but contains only `GraphDict` items.
```python
from kgcnn.data import MemoryGraphList

# List of graph dicts.
graph_list = MemoryGraphList([{"edge_indices": [[0, 1], [1, 0]]}, {"edge_indices": [[0, 0]]}, {}])
graph_list.clean(["edge_indices"])  # Remove graphs without the property.
graph_list.get("edge_indices")  # Opposite is set().

# Easily cast to tensor; makes a copy.
tensor = graph_list.tensor([{"name": "edge_indices"}])  # Config as for a keras `Input` layer.

# Or directly modify the list.
for i, x in enumerate(graph_list):
    x.set("graph_number", [i])
print(len(graph_list), graph_list[:2])  # Also supports indexing with lists.
```
The `MemoryGraphDataset` inherits from `MemoryGraphList` but must be initialized with file information on disk that points to a `data_directory` for the dataset. The `data_directory` can have a subdirectory for files and/or a single file such as a CSV file:
```
├── data_directory
    ├── file_directory
    │   ├── *.*
    │   └── ...
    ├── file_name
    └── dataset_name.kgcnn.pickle
```
A base dataset class is created with path and name information:
```python
from kgcnn.data import MemoryGraphDataset

dataset = MemoryGraphDataset(
    data_directory="ExampleDir/",
    dataset_name="Example",
    file_name=None,
    file_directory=None,
)
dataset.save()  # Opposite is load().
```
The subclasses `QMDataset`, `ForceDataset`, `MoleculeNetDataset`, `CrystalDataset` and `GraphTUDataset` further have functions required for the specific dataset type to convert and process files such as '.txt', '.sdf', '.xyz' etc. Most subclasses implement `prepare_data()` and `read_in_memory()` with dataset-dependent arguments. An example for `MoleculeNetDataset` is shown below. For more details, find tutorials in notebooks.
```python
from kgcnn.data.moleculenet import MoleculeNetDataset

# File directory and files must exist.
# Here 'ExampleDir' and 'ExampleDir/data.csv' with columns "smiles" and "label".
dataset = MoleculeNetDataset(
    dataset_name="Example",
    data_directory="ExampleDir/",
    file_name="data.csv",
)
dataset.prepare_data(
    overwrite=True,
    smiles_column_name="smiles",
    add_hydrogen=True,
    make_conformers=True,
    optimize_conformer=True,
    num_workers=None,
)
dataset.read_in_memory(
    label_column_name="label",
    add_hydrogen=False,
    has_conformers=True,
)
```
In `data.datasets` there are graph learning benchmark datasets as subclasses, which are downloaded from popular graph archives such as TUDatasets, MatBench or MoleculeNet. The subclasses `GraphTUDataset2020`, `MatBenchDataset2020` and `MoleculeNetDataset2018` download and read the available datasets by name. There are also specific dataset subclasses for each dataset, to handle additional processing or downloading from individual sources:
```python
from kgcnn.data.datasets.MUTAGDataset import MUTAGDataset

dataset = MUTAGDataset()  # Inherits from GraphTUDataset2020.
```
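Since the dataset behaves like a `MemoryGraphList`, it can be inspected and cast directly. A short sketch; the property name below is an assumption for MUTAG:

```python
print(len(dataset))  # Number of graphs in MUTAG.
print(dataset[0])    # A single GraphDict.
labels = dataset.get("graph_labels")  # Assumed property name for the class labels.
```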
Downloaded datasets are stored in `~/.kgcnn/datasets` on your computer. Please remove them manually if they are no longer required.
A set of example training scripts can be found in training. Training scripts are configurable with a hyperparameter config file and command-line arguments regarding model and dataset.
You can find a table of common benchmark datasets in results.
Some known issues to be aware of when using or building new models or layers with kgcnn:
- Loading jagged or nested tensors into models is not working for the PyTorch backend.
- The BatchNormalization layer does not support padding yet.
- The Keras AUC metric does not seem to work for torch on CUDA.
If you want to cite this repo, please refer to our paper:
```bibtex
@article{REISER2021100095,
  title = {Graph neural networks in TensorFlow-Keras with RaggedTensor representation (kgcnn)},
  journal = {Software Impacts},
  pages = {100095},
  year = {2021},
  issn = {2665-9638},
  doi = {https://doi.org/10.1016/j.simpa.2021.100095},
  url = {https://www.sciencedirect.com/science/article/pii/S266596382100035X},
  author = {Patrick Reiser and Andre Eberhard and Pascal Friederich}
}
```