GraphAny: Fully-inductive Node Classification on Arbitrary Graphs


DeepGraphLearning/GraphAny


Original PyTorch implementation of GraphAny.

Authored by Jianan Zhao, Zhaocheng Zhu, Mikhail Galkin, Hesham Mostafa, Michael Bronstein, and Jian Tang.

Overview

Fully-Inductive Model on Node Classification

GraphAny is a fully-inductive model for node classification. A single trained GraphAny model performs node classification on any graph with any feature and label spaces. Averaged over 30+ graphs, a single trained GraphAny model in inference mode outperforms many transductive (supervised) models (e.g., MLP, GCN, and GAT) trained specifically for each graph. Following the pretrain-inference paradigm of foundation models, you can train from scratch and run inference on 30 datasets as shown in Training GraphAny from Scratch.

This repository is based on PyTorch 2.1, PyTorch Lightning 2.2, PyG 2.4, DGL 2.1, and Hydra 1.3.

Environment Setup

Our experiments are designed to run on both GPU and CPU platforms. A GPU with 16 GB of memory is sufficient to handle all 31 datasets, and we have also tested the setup on a single CPU (specifically, an M1 MacBook).

To configure your environment, use the following commands based on your setup:

    # For setups with a GPU (requires CUDA 11.8):
    conda env create -f environment.yaml
    # For setups using a CPU (tested on macOS with M1 chip):
    conda env create -f environment_cpu.yaml
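After creating the environment, a quick check that the main dependencies can be found may save a failed run later. The sketch below uses only the standard library; the import names in `REQUIRED_MODULES` are assumptions based on the packages listed above, not something the repository prescribes.

```python
from importlib.util import find_spec

# Import names assumed for PyTorch, PyTorch Lightning, PyG, DGL, and Hydra.
REQUIRED_MODULES = ["torch", "pytorch_lightning", "torch_geometric", "dgl", "hydra"]


def check_modules(names):
    """Return {module_name: bool} telling whether each top-level module can be found."""
    return {name: find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules(REQUIRED_MODULES).items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
```

Run it inside the activated conda environment; any module reported as MISSING suggests the environment file was not applied completely.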

File Structure

    ├── README.md
    ├── checkpoints
    ├── configs
    │   ├── data.yaml
    │   ├── main.yaml
    │   └── model.yaml
    ├── environment.yaml
    ├── environment_cpu.yaml
    └── graphany
        ├── __init__.py
        ├── data.py
        ├── model.py
        ├── run.py
        └── utils

Reproduce Our Results

Training GraphAny from Scratch

This section describes how to train GraphAny on one dataset (Cora, Wisconsin, Arxiv, or Product) and evaluate it on all 31 datasets. You can reproduce our results via the commands below. The checkpoints produced by these commands are saved in the checkpoints/ folder.

    cd path/to/this/repo
    # Reproduce GraphAny-Cora: test_acc=66.98 for seed 0
    python graphany/run.py dataset=CoraXAll total_steps=500 n_hidden=64 n_mlp_layer=1 entropy=2 n_per_label_examples=5
    # Reproduce GraphAny-Wisconsin: test_acc=67.36 for seed 0
    python graphany/run.py dataset=WisXAll total_steps=1000 n_hidden=32 n_mlp_layer=2 entropy=1 n_per_label_examples=5
    # Reproduce GraphAny-Arxiv: test_acc=67.58 for seed 0
    python graphany/run.py dataset=ArxivXAll total_steps=1000 n_hidden=128 n_mlp_layer=2 entropy=1 n_per_label_examples=3
    # Reproduce GraphAny-Product: test_acc=67.77 for seed 0
    python graphany/run.py dataset=ProdXAll total_steps=1000 n_hidden=128 n_mlp_layer=2 entropy=1 n_per_label_examples=3
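If you want to launch the four runs from a script, the commands can be assembled from a small table of overrides. This is a convenience sketch, not part of the repository: the `RUNS` table and `build_command` helper are hypothetical names, and the hyperparameters are copied from the commands above.

```python
import shlex

# Hyperparameters copied from the reproduction commands.
RUNS = {
    "CoraXAll": dict(total_steps=500, n_hidden=64, n_mlp_layer=1, entropy=2, n_per_label_examples=5),
    "WisXAll": dict(total_steps=1000, n_hidden=32, n_mlp_layer=2, entropy=1, n_per_label_examples=5),
    "ArxivXAll": dict(total_steps=1000, n_hidden=128, n_mlp_layer=2, entropy=1, n_per_label_examples=3),
    "ProdXAll": dict(total_steps=1000, n_hidden=128, n_mlp_layer=2, entropy=1, n_per_label_examples=3),
}


def build_command(dataset, overrides):
    """Build the argument list for one run using key=value overrides."""
    args = ["python", "graphany/run.py", f"dataset={dataset}"]
    args += [f"{key}={value}" for key, value in overrides.items()]
    return args


if __name__ == "__main__":
    for dataset, overrides in RUNS.items():
        # Printed rather than executed; pass the list to subprocess.run(..., check=True) to launch it.
        print(shlex.join(build_command(dataset, overrides)))
```

Keeping the overrides in one table makes it easy to add a seed or a wandb option to every run at once.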

Inference Using Pre-trained Checkpoints

Once trained, GraphAny can perform inference on any graph. You can use our trained checkpoints to run inference on your own graph. Here, we show an example of loading a GraphAny model trained on Arxiv and performing inference on Cora and Citeseer.

Step 1: Define your custom combined dataset config in configs/data.yaml:

    # configs/data.yaml
    _dataset_lookup:
      # Train on Arxiv, inference on Cora and Citeseer
      CoraCiteInference:
        train: [ Arxiv ]
        eval: [ Cora, Citeseer ]

Step 2 (optional): Define your dataset processing logic in graphany/data.py. This step is necessary only if you are not using our pre-processed data. If you use our provided datasets, skip this step and proceed directly to Step 3.

Step 3: Run inference with the pre-trained model using the command:

    python graphany/run.py prev_ckpt=checkpoints/graph_any_arxiv.pt total_steps=0 dataset=CoraCiteInference
    # ind/cora_test_acc 79.4 ind/cite_test_acc 68.4

Example output log:

    # Training Logs
    CRITICAL {'ind/cora_val_acc': 75.4,
              'ind/cite_val_acc': 70.4,
              'val_acc': 72.9,
              'trans_val_acc': nan,  # Not applicable as Arxiv is not included in the evaluation set
              'ind_val_acc': 72.9,
              'heldout_val_acc': 70.4,
              'ind/cora_test_acc': 79.4,
              'ind/cite_test_acc': 68.4,
              'test_acc': 73.9,
              'trans_test_acc': nan,
              'ind_test_acc': 73.9,
              'heldout_test_acc': 68.4
              }
    INFO Finished main at 06-01 05:07:49, running time = 2.52s.

Note: The trans_test_acc field is not applicable since Arxiv is not specified in the evaluation datasets. Additionally, the heldout accuracies are calculated by excluding datasets specified as transductive in configs/data.yaml (default settings: _trans_datasets: [Arxiv, Product, Cora, Wisconsin]). To use the heldout metrics correctly, please adjust these transductive datasets in your configuration to reflect your specific inductive split settings.
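The way the aggregate metrics relate to the per-dataset accuracies can be sketched in a few lines. This is an illustration under assumptions, not the repository's code: it assumes each aggregate is an unweighted mean over the datasets in its category (which matches the numbers in the log above), and the function names are hypothetical.

```python
import math


def mean_or_nan(values):
    """Unweighted mean, or NaN when there is nothing to average."""
    values = list(values)
    return sum(values) / len(values) if values else math.nan


def aggregate(per_dataset_acc, train_datasets, trans_datasets):
    """Group per-dataset accuracies into overall, transductive, inductive, and heldout.

    per_dataset_acc: {dataset_name: accuracy} for the evaluation datasets
    train_datasets:  datasets the model was trained on
    trans_datasets:  the _trans_datasets list that heldout metrics exclude
    """
    items = per_dataset_acc.items()
    return {
        "test_acc": mean_or_nan(per_dataset_acc.values()),
        "trans_test_acc": mean_or_nan(a for d, a in items if d in train_datasets),
        "ind_test_acc": mean_or_nan(a for d, a in items if d not in train_datasets),
        "heldout_test_acc": mean_or_nan(a for d, a in items if d not in trans_datasets),
    }


result = aggregate(
    {"Cora": 79.4, "Citeseer": 68.4},
    train_datasets=["Arxiv"],
    trans_datasets=["Arxiv", "Product", "Cora", "Wisconsin"],
)
print({key: round(value, 1) for key, value in result.items()})
```

With Arxiv absent from the evaluation set the transductive mean has no entries and comes out as NaN, and because Cora is in the default transductive list the heldout value reduces to the Citeseer accuracy alone.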

Configuration Details

We use Hydra to manage the configuration. The configs are organized in three files under the configs/ directory:

main.yaml

Settings for experiments, including random seed, wandb, path, hydra, and logging configs.

data.yaml

This file contains settings for datasets, including preprocessing specifications, metadata, and lookup configurations. Here is an overview of the key elements:


Dataset Preprocessing Options

  • preprocess_device: gpu - Specifies the device for computing propagated features $\boldsymbol{F}$. Set to cpu if your GPU memory is below 32GB.
  • add_self_loop: false - Specifies whether to add self-loops to the nodes in the graph.
  • to_bidirected: true - If set to true, edges are made bidirectional.
  • n_hops: 2 - Defines the maximum number of hops of message passing. In our experiments, besides Linear, we use LinearSGC1, LinearSGC2, LinearHGC1, and LinearHGC2, all of which use information within 2 hops of message passing.

Train and Evaluation Dataset Lookup

  • The datasets for training and evaluation are selected dynamically from the command-line arguments by looking them up in the _dataset_lookup configuration.
  • Example: Using dataset=CoraXAll sets train_datasets to [Cora] and eval_datasets to all datasets (31 in total).

    train_datasets: ${oc.select:_dataset_lookup.${dataset}.train,${dataset}}
    eval_datasets: ${oc.select:_dataset_lookup.${dataset}.eval,${dataset}}
    _dataset_lookup:
      CoraXAll:
        train: [Cora]
        eval: ${_all_datasets}

Please define your own dataset combinations in _dataset_lookup if desired.
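The lookup relies on OmegaConf's `oc.select` resolver, which returns a default when the requested key does not exist. The plain-Python sketch below only mimics that fallback to make the behavior concrete; `DATASET_LOOKUP`, `ALL_DATASETS`, and `resolve` are hypothetical stand-ins, and the dataset list is shortened for illustration.

```python
ALL_DATASETS = ["Cora", "Citeseer", "Arxiv"]  # placeholder; the real list has 31 entries

DATASET_LOOKUP = {
    "CoraXAll": {"train": ["Cora"], "eval": ALL_DATASETS},
    "CoraCiteInference": {"train": ["Arxiv"], "eval": ["Cora", "Citeseer"]},
}


def resolve(dataset):
    """Mimic ${oc.select:_dataset_lookup.<dataset>.<split>,<dataset>}.

    A registered combination returns its train/eval lists; anything else
    falls back to the dataset name itself for both splits.
    """
    entry = DATASET_LOOKUP.get(dataset, {})
    return entry.get("train", dataset), entry.get("eval", dataset)


print(resolve("CoraXAll"))  # registered combination
print(resolve("Cora"))      # not registered, so the name itself is used
```

The fallback is what lets a plain dataset name be passed as `dataset=...` without registering a combination first.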

Detailed Dataset Configurations

The dataset metadata stores the information needed to load each dataset: its interface (DGL, PyG, OGB, or Heterophilous) and its alias (e.g. Planetoid.Cora). The statistics are provided in the comment in the format 'n_nodes, n_edges, n_feat_dim, n_labels'. For example:

    _ds_meta_data:
      Arxiv: ogb, ogbn-arxiv # 169,343 1,166,243 128 40
      Cora: pyg, Planetoid.Cora # 2,708 10,556 1,433 7
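Each metadata value is a single string holding the data source and the alias. A minimal sketch of splitting such a value is shown below; how the repository actually parses it is not shown here, so `parse_meta` is a hypothetical helper.

```python
def parse_meta(value):
    """Split a 'data_source, alias' string into its two parts."""
    data_source, alias = (part.strip() for part in value.split(",", 1))
    return data_source, alias


meta = {"Arxiv": "ogb, ogbn-arxiv", "Cora": "pyg, Planetoid.Cora"}
for name, value in meta.items():
    print(name, parse_meta(value))
```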

model.yaml

This file contains the settings for models and training.


GraphAny leverages interactions between predictions as input features for an MLP to calculate inductive attention scores. These inputs are termed "feature channels" and are defined in the configuration file as feat_chn. The outputs from LinearGNNs, referred to as "prediction channels", are then combined using the inductive attention scores and are defined as pred_chn in the configuration file. The default settings are:

    feat_chn: X+L1+L2+H1+H2 # X=Linear, L1=LinearSGC1, L2=LinearSGC2, H1=LinearHGC1, H2=LinearHGC2
    pred_chn: X+L1+L2 # H1 and H2 channels are masked to enhance convergence speed.

Note that the feature channels and prediction channels do not need to be identical. Empirically, masking LinearHGC1 and LinearHGC2 leads to faster convergence and marginally improved results (results in Table 2, Figure 1, and Figure 5). For the attention visualizations in Figure 6, all five channels (pred_chn=X+L1+L2+H1+H2) are employed. This demonstrates GraphAny's capability to learn inductive attention that effectively identifies critical channels for unseen graphs.
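To make the distinction between the two channel sets concrete, the sketch below parses the default channel strings and combines per-channel predictions for one node with softmax attention weights. It is a simplified illustration, not the model code: the attention scores are hard-coded rather than produced by the MLP, the per-channel predictions are made up, and dividing the scores by the temperature is an assumption about how attn_temp is applied.

```python
import math

feat_chn = "X+L1+L2+H1+H2".split("+")
pred_chn = "X+L1+L2".split("+")
masked = [c for c in feat_chn if c not in pred_chn]  # channels left out of the output


def softmax(scores, temperature=1.0):
    """Numerically stable softmax; dividing by the temperature is an assumption."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def combine(channel_preds, scores, temperature=1.0):
    """Attention-weighted sum of per-channel class probabilities for one node."""
    weights = softmax(scores, temperature)
    n_classes = len(next(iter(channel_preds.values())))
    return [
        sum(w * channel_preds[c][k] for w, c in zip(weights, pred_chn))
        for k in range(n_classes)
    ]


# Made-up predictions for a single node with three classes.
preds = {"X": [0.6, 0.3, 0.1], "L1": [0.2, 0.7, 0.1], "L2": [0.1, 0.8, 0.1]}
print(masked)
print(combine(preds, scores=[0.0, 1.0, 2.0], temperature=5.0))
```

The masked channels still appear in feat_chn, so they can inform the attention scores even though their predictions are not mixed into the output.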

Other model parameters and default values:

    # The entropy to normalize the distance features (conditional Gaussian distribution). The standard deviation of the conditional Gaussian distribution is dynamically determined via binary search; defaults to 1
    entropy: 1
    attn_temp: 5 # The temperature for attention normalization
    n_hidden: 128 # The hidden dimension of MLP
    n_mlp_layer: 2
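The entropy setting refers to choosing the standard deviation of a conditional Gaussian by binary search so that the resulting distribution has a fixed entropy. The sketch below shows the general idea on made-up squared distances. The kernel form, the use of bits as the entropy unit, and the search bounds are all assumptions for illustration and may differ from the repository's implementation.

```python
import math


def conditional_probs(sq_dists, sigma):
    """p_i proportional to exp(-d_i / (2 * sigma^2)), shifted for numerical stability."""
    shift = min(sq_dists)
    weights = [math.exp(-(d - shift) / (2.0 * sigma ** 2)) for d in sq_dists]
    total = sum(weights)
    return [w / total for w in weights]


def entropy_bits(probs):
    """Shannon entropy in bits, skipping zero-probability entries."""
    return -sum(p * math.log2(p) for p in probs if p > 0.0)


def find_sigma(sq_dists, target_entropy, lo=1e-3, hi=1e3, steps=100):
    """Binary search for sigma, using the fact that entropy grows with sigma."""
    for _ in range(steps):
        mid = math.sqrt(lo * hi)  # geometric midpoint, since sigma spans orders of magnitude
        if entropy_bits(conditional_probs(sq_dists, mid)) < target_entropy:
            lo = mid
        else:
            hi = mid
    return math.sqrt(lo * hi)


sq_dists = [0.1, 0.4, 0.9, 1.6]  # made-up squared distances
sigma = find_sigma(sq_dists, target_entropy=1.0)
print(sigma, entropy_bits(conditional_probs(sq_dists, sigma)))
```

The target must lie below the maximum attainable entropy, which is log2 of the number of distances (2 bits for the four values here); otherwise the search simply runs to its upper bound.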

Bring Your Own Dataset

We support three major sources of graph dataset interfaces: DGL, PyG, and OGB. If you are interested in adding your own dataset, here is how we integrated the cleaned Texas dataset processed by this paper. The original Texas dataset contains 5 classes, one of which has only a single node, which makes using that class for training and evaluation meaningless.

In the example below, we demonstrate how to add a dataset called "Texas" with 4 classes from a new data source termed heterophilous.

Step 1: Update configs/data.yaml:

First, define your dataset's metadata.

    # configs/data.yaml
    _ds_meta_data:
      # key: dataset name, value: data_source, alias
      Texas: heterophilous, texas_4_classes

The data_source is set as 'heterophilous', which is handled differently from other sources ('pyg', 'dgl', 'ogb').

Additionally, update the _dataset_lookup with a new setting:

    # configs/data.yaml
    _dataset_lookup:
      Debug:
        train: [ Wisconsin ]
        eval: [ Texas ]

Step 2: Implement the dataset interface:

Implement load_heterophilous_dataset in data.py to download and process the dataset.

    import numpy as np
    import torch
    from graphany.data import download_url
    import dgl


    def load_heterophilous_dataset(url, raw_dir):
        # Converts Heterophilous dataset to DGL Graph format
        download_path = download_url(url, raw_dir)
        data = np.load(download_path)
        node_features = torch.tensor(data['node_features'])
        labels = torch.tensor(data['node_labels'])
        edges = torch.tensor(data['edges'])

        graph = dgl.graph((edges[:, 0], edges[:, 1]),
                          num_nodes=len(node_features), idtype=torch.int32)
        num_classes = len(labels.unique())
        train_mask, val_mask, test_mask = (torch.tensor(data['train_mask']),
                                           torch.tensor(data['val_mask']),
                                           torch.tensor(data['test_mask']))
        return graph, labels, num_classes, node_features, train_mask, val_mask, test_mask

Step 3: Update the GraphDataset class in data.py:

Modify the initialization and dataset loading functions:

    # In GraphDataset.__init__():
    if self.data_source in ['dgl', 'pyg', 'ogb']:
        pass  # Code for other data sources omitted for brevity
    elif self.data_source == 'heterophilous':
        # Hydra needs a fully qualified dotted path; relative paths are not supported
        target = 'graphany.data.load_heterophilous_dataset'
        url = f'https://example.com/data/{ds_alias}.npz'
        ds_init_args = {
            "_target_": target,
            'raw_dir': f'{cfg.dirs.data_storage}{self.data_source}/',
            'url': url
        }
    else:
        raise NotImplementedError(f'Unsupported data source: {self.data_source}')

    # In GraphDataset.load_dataset():
    from hydra.utils import instantiate

    def load_dataset(self, data_init_args):
        dataset = instantiate(data_init_args)
        if self.data_source in ['dgl', 'pyg', 'ogb']:
            pass  # Code for other data sources omitted for brevity
        elif self.data_source == 'heterophilous':
            g, label, num_class, feat, train_mask, val_mask, test_mask = dataset
        # Rest of the code omitted for brevity

You can now run the code using the following commands:

    # Training from scratch
    python graphany/run.py dataset=Debug total_steps=500
    # Inference using existing checkpoint
    python graphany/run.py prev_ckpt=checkpoints/graph_any_wisconsin.pt dataset=Debug total_steps=0
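The dataset loader in Step 3 is located through the `_target_` key, which Hydra's `instantiate` resolves from a dotted import path before calling it with the remaining keys. The sketch below is a simplified standard-library stand-in for that mechanism, written only to show why the path has to be importable; it is not Hydra's implementation, and `instantiate_sketch` is a hypothetical name.

```python
from importlib import import_module


def instantiate_sketch(init_args):
    """Resolve a dotted _target_ path and call it with the remaining keys."""
    kwargs = dict(init_args)
    module_path, _, attr = kwargs.pop("_target_").rpartition(".")
    if not module_path or module_path.startswith("."):
        raise ValueError("_target_ must be a fully qualified dotted path")
    target = getattr(import_module(module_path), attr)
    return target(**kwargs)


# Standard-library callable used as a stand-in for a dataset loader.
print(instantiate_sketch({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}))
```

In the same way, the keys next to `_target_` in ds_init_args (raw_dir and url) become the keyword arguments of the loader function.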

Using Wandb for Enhanced Visualization

We recommend using Weights & Biases (wandb) for advanced visualization capabilities. As an example, consider the visualizations for the GraphAny-Arxiv project shown below, which illustrate the validation accuracy across different dataset categories:

  • Transductive: Training dataset (i.e. Arxiv)
  • Heldout: 27 datasets (except Cora, Wisconsin, Arxiv, Product)
  • Inductive: 30 datasets (except Arxiv)
  • Overall: 31 datasets (all datasets)

[Figure: wandb training curves]

By default, wandb integration is disabled. To enable and configure wandb for your project, append the following options to your run command, substituting YourOwnWandbEntity with your actual Weights & Biases entity name:

    use_wandb=true wandb_proj=GraphAny wandb_entity=YourOwnWandbEntity

This setup will allow you to track and visualize metrics dynamically.

Citation

If you find this codebase useful in your research, please cite the paper.

    @article{zhao2025graphany,
      title = {Fully-inductive Node Classification on Arbitrary Graphs},
      author = {Jianan Zhao and Zhaocheng Zhu and Mikhail Galkin and Hesham Mostafa and Michael Bronstein and Jian Tang},
      journal = {International Conference on Learning Representations},
      year = {2025}
    }
