
Keras Graph Convolution Neural Networks

General | Requirements | Installation | Documentation | Implementation details | Literature | Data | Datasets | Training | Issues | Citing | References

General

The package in `kgcnn` contains several layer classes to build up graph convolution models in Keras with TensorFlow, PyTorch or Jax as backend. Some models from the literature are given as examples. A documentation is generated in `docs`. The focus of `kgcnn` is (batched) graph learning for molecules (`kgcnn.molecule`) and materials (`kgcnn.crystal`). If you want to get in contact, feel free to open a discussion.

Note that kgcnn>=4.0.0 requires keras>=3.0.0. Previous versions of kgcnn focused on TensorFlow ragged tensors; their model hyperparameters should also transfer to kgcnn 4.0 by adding `input_tensor_type: "ragged"` and checking the order and `dtype` of the inputs.

Requirements

Standard Python package requirements are installed automatically. However, you must make sure to install GPU/TPU acceleration for the backend of your choice.

Installation

Clone the repository or the latest release and install it in editable mode, or install the latest release via the Python Package Index:

```bash
pip install kgcnn
```
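
For development, an editable install from a clone of the repository is the usual alternative (a minimal sketch using the repository URL from above):

```bash
git clone https://github.com/aimat-lab/gcnn_keras.git
cd gcnn_keras
pip install -e .
```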

Documentation

Auto-documentation is generated at https://kgcnn.readthedocs.io/en/latest/index.html.

Implementation details

Representation

A graph of N nodes and M edges is commonly represented by a list of node or edge attributes, `node_attr` or `edge_attr`, respectively, plus a list of index pairs (i, j) that represent the directed edges of the graph: `edge_index`. The feature dimension of the attributes is denoted by F. Alternatively, an adjacency matrix A_ij of shape (N, N) can be used, which has 'one' entries where there is an edge between nodes and 'zeros' elsewhere. Consequently, the sum over A_ij gives the number of edges M.
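
As a minimal sketch of this representation (plain NumPy; the array names follow the conventions above), a triangle graph with N=3 nodes, M=6 directed edges and F=2 node features could look like:

```python
import numpy as np

# Triangle graph: N=3 nodes, F=2 node features, M=6 directed edges.
node_attr = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # Shape (N, F).
edge_index = np.array([[0, 1], [1, 0], [1, 2], [2, 1], [0, 2], [2, 0]])  # Pairs (i, j), shape (M, 2).
edge_attr = np.ones((6, 1))  # Shape (M, F).

# Equivalent adjacency matrix A_ij of shape (N, N).
A = np.zeros((3, 3))
A[edge_index[:, 0], edge_index[:, 1]] = 1.0
assert int(A.sum()) == len(edge_index)  # Sum over A_ij gives M.
```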

Input

For learning on batched or single graphs, the following tensor representations can be chosen:

Batched Graphs
  • `node_attr`: Node attributes of shape `(batch, N, F)` and dtype `float`
  • `edge_attr`: Edge attributes of shape `(batch, M, F)` and dtype `float`
  • `edge_index`: Indices of shape `(batch, M, 2)` and dtype `int`
  • `graph_attr`: Graph attributes of shape `(batch, F)` and dtype `float`

Graphs are stacked along the batch dimension `batch`. Note that for flexibly sized graphs the tensors have to be padded up to a maximum N or M, or ragged tensors with a ragged rank of one are used.
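
For illustration, a padded batch of two differently sized graphs could be assembled as follows (a NumPy sketch; in kgcnn the true node counts are passed alongside the padded tensors, as in the model example below):

```python
import numpy as np

# Two graphs with 2 and 3 nodes and F=4 features, padded to max_N=3.
g1, g2 = np.random.rand(2, 4), np.random.rand(3, 4)
max_n = 3
node_attr = np.stack([np.pad(g1, ((0, max_n - len(g1)), (0, 0))),
                      np.pad(g2, ((0, max_n - len(g2)), (0, 0)))])  # Shape (batch, N, F) = (2, 3, 4).
total_nodes = np.array([2, 3])  # True number of nodes per graph.
```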

Disjoint Graphs
  • `node_attr`: Node attributes of shape `([N], F)` and dtype `float`
  • `edge_attr`: Edge attributes of shape `([M], F)` and dtype `float`
  • `edge_index`: Indices of shape `(2, [M])` and dtype `int`
  • `batch_ID`: Graph ID of shape `([N], )` and dtype `int`

Here, the lists essentially represent one graph that consists of the disjoint sub-graphs of the batch, a representation introduced by PyTorch Geometric (PyG). For pooling, the graph assignment is stored in `batch_ID`. Note that for Jax we cannot have dynamic shapes, so a padded disjoint representation is used which assigns all padded nodes to a discarded graph with index zero.
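
The disjoint form of the same two graphs can be sketched in plain NumPy as follows (in kgcnn this conversion is done by the casting layers in `kgcnn.layers.casting`):

```python
import numpy as np

# Two graphs: 2 and 3 nodes with F=4 features, plus per-graph edge lists.
nodes_1, nodes_2 = np.random.rand(2, 4), np.random.rand(3, 4)
edges_1 = np.array([[0, 1], [1, 0]])
edges_2 = np.array([[0, 1], [1, 2], [2, 0]])

# Concatenate nodes; shift indices of the second graph by the node offset.
node_attr = np.concatenate([nodes_1, nodes_2])  # Shape ([N], F) = (5, 4).
edge_index = np.concatenate([edges_1, edges_2 + len(nodes_1)]).T  # Shape (2, [M]) = (2, 5).
batch_ID = np.array([0, 0, 1, 1, 1])  # Graph assignment for each node.
```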

Model

The keras layers in `kgcnn.layers` can be used with the PyG-compatible tensor representation, or even by simply wrapping a PyG model with `TorchModuleWrapper`. Efficient model loading can be achieved in multiple ways (see `kgcnn.io`). For the most simple keras-like behaviour, the model can be fed with batched padded or ragged tensors, which are converted to and from the disjoint representation around the PyG-equivalent model. Here is an example of a minimal message-passing GNN:

```python
import keras as ks
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.aggr import AggregateLocalEdges

# Example for padded input.
ns = ks.layers.Input(shape=(None, 64), dtype="float32", name="node_attributes")
e_idx = ks.layers.Input(shape=(None, 2), dtype="int64", name="edge_indices")
total_n = ks.layers.Input(shape=(), dtype="int64", name="total_nodes")  # Or mask.
total_e = ks.layers.Input(shape=(), dtype="int64", name="total_edges")  # Or mask.

# Cast batched padded tensors to the disjoint representation.
n, idx, batch_id, _, _, _, _, _ = CastBatchedIndicesToDisjoint(uses_mask=False)(
    [ns, e_idx, total_n, total_e])
n_in_out = GatherNodes()([n, idx])
node_messages = ks.layers.Dense(64, activation="relu")(n_in_out)
node_updates = AggregateLocalEdges()([n, node_messages, idx])
n_node_updates = ks.layers.Concatenate()([n, node_updates])
n_embedding = ks.layers.Dense(1)(n_node_updates)
g_embedding = PoolingNodes()([total_n, n_embedding, batch_id])

message_passing = ks.models.Model(
    inputs=[ns, e_idx, total_n, total_e], outputs=g_embedding)
```
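
A usage sketch for the model above with random padded inputs (the shapes follow the `Input` definitions; values are dummies and the output values depend on the random weights):

```python
import numpy as np

nodes = np.random.rand(2, 3, 64).astype("float32")  # (batch, max_N, F).
indices = np.array([[[0, 1], [1, 0], [0, 0]],
                    [[0, 1], [1, 2], [2, 0]]], dtype="int64")  # (batch, max_M, 2); graph one is padded.
total_nodes = np.array([2, 3], dtype="int64")  # True nodes per graph.
total_edges = np.array([2, 3], dtype="int64")  # True edges per graph.

out = message_passing([nodes, indices, total_nodes, total_edges])
print(out.shape)  # One embedding per graph, here (2, 1).
```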

The actual message-passing model can be structured further by, for example, subclassing the message-passing base layer:

```python
import keras as ks
from kgcnn.layers.message import MessagePassingBase


class MyMessageNN(MessagePassingBase):

    def __init__(self, units, **kwargs):
        super(MyMessageNN, self).__init__(**kwargs)
        self.dense = ks.layers.Dense(units)
        self.add = ks.layers.Add()

    def message_function(self, inputs, **kwargs):
        n_in, n_out, edges = inputs
        return self.dense(n_out, **kwargs)

    def update_nodes(self, inputs, **kwargs):
        nodes, nodes_update = inputs
        return self.add([nodes, nodes_update], **kwargs)
```
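
A possible usage sketch in disjoint representation; note that the input order `[nodes, edges, edge_indices]` is an assumption inferred from `message_function` above, not a verified signature:

```python
import numpy as np

n = np.random.rand(5, 64).astype("float32")   # ([N], F) node features.
e = np.random.rand(4, 8).astype("float32")    # ([M], F) edge features.
idx = np.array([[0, 1, 2, 3], [1, 0, 3, 2]])  # (2, [M]) edge indices.

layer = MyMessageNN(units=64)
n_updated = layer([n, e, idx])  # Assumed input order [nodes, edges, edge_indices].
```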

Literature

The following models, proposed in the literature, have a module in `literature`. The module usually exposes a `make_model` function to create a `keras.models.Model`. The models can, but need not, be built completely from `kgcnn.layers` and can for example include original implementations (with proper licensing).

... and many more (see the documentation for the full list).
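
For example, a model can typically be created from a literature module like this (a hedged sketch: each module ships default hyperparameters, which are normally overwritten from a config file via keyword arguments such as `inputs` or `depth`):

```python
from kgcnn.literature.GCN import make_model

# Build with the module's default hyperparameters (assumed here for brevity).
model = make_model()
model.summary()
```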

Data

Data handling classes are given in `kgcnn.data`, which store graphs as `List[Dict]`.

Graph dictionary

Graphs are represented by a dictionary `GraphDict` of (numpy) arrays, which behaves like a python `dict`. There are graph pre- and postprocessors in `kgcnn.graph` which take specific properties by name and apply a processing function or transformation.

Important

They can perform any operation, but note that `GraphDict` does not impose an actual graph structure! For example, when sorting edge indices, make sure that all edge attributes are sorted accordingly.

```python
from kgcnn.graph import GraphDict

# Single graph.
graph = GraphDict({"edge_indices": [[1, 0], [0, 1]], "node_label": [[0], [1]]})
graph.set("graph_labels", [0])  # Use set() and get() to assign (tensor) properties.
graph.set("edge_attributes", [[1.0], [2.0]])
graph.to_networkx()

# Modify with e.g. a preprocessor.
from kgcnn.graph.preprocessor import SortEdgeIndices
SortEdgeIndices(edge_indices="edge_indices", edge_attributes="^edge_(?!indices$).*",
                in_place=True)(graph)
```

List of graph dictionaries

A `MemoryGraphList` should behave identically to a python list but contain only `GraphDict` items.

```python
from kgcnn.data import MemoryGraphList

# List of graph dicts.
graph_list = MemoryGraphList([{"edge_indices": [[0, 1], [1, 0]]},
                              {"edge_indices": [[0, 0]]}, {}])
graph_list.clean(["edge_indices"])  # Remove graphs that miss the property.
graph_list.get("edge_indices")  # Opposite is set().

# Easily cast to tensor; makes a copy.
tensor = graph_list.tensor([{"name": "edge_indices"}])  # Config of keras `Input` layer.

# Or directly modify the list.
for i, x in enumerate(graph_list):
    x.set("graph_number", [i])
print(len(graph_list), graph_list[:2])  # Also supports indexing with lists.
```

Datasets

The `MemoryGraphDataset` inherits from `MemoryGraphList` but must be initialized with file information on disk that points to a `data_directory` for the dataset. The `data_directory` can have a subdirectory for files and/or a single file such as a CSV file:

```
├── data_directory
    ├── file_directory
    │   ├── *.*
    │   └── ...
    ├── file_name
    └── dataset_name.kgcnn.pickle
```

A base dataset class is created with path and name information:

```python
from kgcnn.data import MemoryGraphDataset

dataset = MemoryGraphDataset(data_directory="ExampleDir/",
                             dataset_name="Example",
                             file_name=None, file_directory=None)
dataset.save()  # Opposite is load().
```

The subclasses `QMDataset`, `ForceDataset`, `MoleculeNetDataset`, `CrystalDataset` and `GraphTUDataset` further have functions required for the specific dataset type to convert and process files such as '.txt', '.sdf', '.xyz' etc. Most subclasses implement `prepare_data()` and `read_in_memory()` with dataset-dependent arguments. An example for `MoleculeNetDataset` is shown below. For more details, find tutorials in `notebooks`.

```python
from kgcnn.data.moleculenet import MoleculeNetDataset

# File directory and files must exist.
# Here: 'ExampleDir' and 'ExampleDir/data.csv' with columns "smiles" and "label".
dataset = MoleculeNetDataset(dataset_name="Example",
                             data_directory="ExampleDir/",
                             file_name="data.csv")
dataset.prepare_data(overwrite=True, smiles_column_name="smiles", add_hydrogen=True,
                     make_conformers=True, optimize_conformer=True, num_workers=None)
dataset.read_in_memory(label_column_name="label", add_hydrogen=False,
                       has_conformers=True)
```

In `data.datasets` there are graph learning benchmark datasets as subclasses which are downloaded from popular graph archives like TUDatasets, MatBench or MoleculeNet. The subclasses `GraphTUDataset2020`, `MatBenchDataset2020` and `MoleculeNetDataset2018` download and read the available datasets by name. There are also specific dataset subclasses for each dataset to handle additional processing or downloading from individual sources:

```python
from kgcnn.data.datasets.MUTAGDataset import MUTAGDataset

dataset = MUTAGDataset()  # Inherits from GraphTUDataset2020.
```

Downloaded datasets are stored in `~/.kgcnn/datasets` on your computer. Please remove them manually if they are no longer required.

Training

A set of example training scripts can be found in `training`. The training scripts are configurable with a hyperparameter config file and command-line arguments regarding model and dataset.
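
A hypothetical invocation could look like the following; the script name and flags are assumptions for illustration only, so consult the scripts in `training` for the actual interface:

```bash
# Hypothetical example; see the `training` folder for the real scripts and flags.
python train_graph.py --hyper hyper/hyper_esol.py --category GCN
```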

You can find a table of common benchmark datasets in `results`.

Issues

Some known issues to be aware of when using or building new models or layers with `kgcnn`:

  • Loading jagged or nested tensors into models does not work for the PyTorch backend.
  • The `BatchNormalization` layer does not support padding yet.
  • The Keras `AUC` metric does not seem to work for torch cuda.

Citing

If you want to cite this repo, please refer to our paper:

```bibtex
@article{REISER2021100095,
  title = {Graph neural networks in TensorFlow-Keras with RaggedTensor representation (kgcnn)},
  journal = {Software Impacts},
  pages = {100095},
  year = {2021},
  issn = {2665-9638},
  doi = {https://doi.org/10.1016/j.simpa.2021.100095},
  url = {https://www.sciencedirect.com/science/article/pii/S266596382100035X},
  author = {Patrick Reiser and Andre Eberhard and Pascal Friederich}
}
```

References

