PyTorch implementation of Sparse Function-space Representation of Neural Networks


AaltoML/sfr


This repository contains a clean and minimal PyTorch implementation of Sparse Function-space Representation (SFR) of Neural Networks. If you'd like to use SFR, we recommend using this repo. Please see sfr-experiments for reproducing the experiments in the ICLR 2024 paper.

Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)
Paper | Code | Website
Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning
Paper | Code | Website

Install

CPU

Create an environment with:

conda env create -f env_cpu.yaml

Activate the environment with:

source activate sfr

NVIDIA GPU

Create an environment with:

conda env create -f env_nvidia.yaml

Activate the environment with:

source activate sfr
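After activating either environment, it can help to confirm that PyTorch is importable and, for the NVIDIA setup, that it can actually see a GPU. The sketch below uses only the standard-library `importlib` plus the standard `torch.__version__` and `torch.cuda.is_available()` calls; the function name `check_environment` is our own and is not part of this repo.

```python
import importlib.util


def check_environment():
    """Report whether PyTorch is importable and whether it can see a CUDA GPU."""
    info = {"torch_installed": False, "torch_version": None, "cuda_available": False}
    if importlib.util.find_spec("torch") is None:
        return info  # PyTorch is not installed in the active environment

    import torch

    info["torch_installed"] = True
    info["torch_version"] = str(torch.__version__)
    # False on the CPU environment, or if the NVIDIA driver/CUDA setup is broken
    info["cuda_available"] = bool(torch.cuda.is_available())
    return info


if __name__ == "__main__":
    print(check_environment())
```

On the CPU environment `cuda_available` is expected to be `False`; on the NVIDIA environment a `False` result usually points to a driver or CUDA mismatch rather than a problem with SFR itself.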

Usage

See the notebooks for how to use our code for both regression and classification.

Image Classification

We provide a minimal training script in train.py, which can be used to train a CNN and fit SFR on MNIST/Fashion-MNIST/CIFAR-10. We advise running it on a GPU.

Example

Here's a short example:

import src
import torch

torch.set_default_dtype(torch.float64)


def func(x, noise=True):
    # Note: `noise` is accepted but not used, so the targets are noise-free
    return torch.sin(x * 5) / x + torch.cos(x * 10)


# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision

data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # setting this reduces the memory required for computing dual parameters
    jitter=1e-4,
)

# Train the network by minimising SFR's loss
sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data)  # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
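The example couples training and prediction through a Gaussian prior (precision `delta`) and a Gaussian likelihood (`sigma_noise`, which we assume is a noise standard deviation). As a rough guide to what those two objects contribute, the plain-Python sketch below writes out three standard Gaussian quantities: the negative log joint up to additive constants, the first and negative second derivatives of the log-likelihood with respect to the network output (the kind of per-point quantity dual parameters are typically built from), and the output-space variance obtained by adding observation noise to a function-space variance. The function names are hypothetical, and this is not a transcription of `src`: we have not checked how `sfr.loss` scales its terms (for example, per-batch averaging) or exactly what `set_data` and `predict` compute.

```python
def neg_log_joint(y, f, w, sigma_noise, delta):
    """Negative log joint, constants dropped: Gaussian likelihood with std
    `sigma_noise` plus an isotropic Gaussian prior with precision `delta` on `w`."""
    data_term = sum((yi - fi) ** 2 for yi, fi in zip(y, f)) / (2.0 * sigma_noise**2)
    prior_term = 0.5 * delta * sum(wi**2 for wi in w)
    return data_term + prior_term


def gaussian_dual_params(y, f, sigma_noise):
    """Per-point derivatives of log N(y | f, sigma_noise^2) with respect to f.

    alpha_i = d/df log p(y_i | f)          = (y_i - f_i) / sigma_noise^2
    beta_i  = -d^2/df^2 log p(y_i | f)     = 1 / sigma_noise^2
    """
    alpha = [(yi - fi) / sigma_noise**2 for yi, fi in zip(y, f)]
    beta = [1.0 / sigma_noise**2 for _ in f]
    return alpha, beta


def output_space_variance(f_var, sigma_noise):
    """Predictive variance of y: function-space variance plus noise variance."""
    return [v + sigma_noise**2 for v in f_var]


y = [1.0, 2.0]       # targets
f = [0.5, 2.5]       # network outputs at the training inputs
w = [1.0, -2.0]      # (tiny) weight vector for the prior term
print(neg_log_joint(y, f, w, sigma_noise=2.0, delta=0.5))  # 1.3125
print(gaussian_dual_params(y, f, sigma_noise=2.0))  # ([0.125, -0.125], [0.25, 0.25])
print(output_space_variance([0.1, 0.2], sigma_noise=2.0))
```

With a Gaussian likelihood the second-derivative term is the same constant for every point, which is one reason regression is the simplest setting in which to inspect these quantities; a classification likelihood would give input-dependent values.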
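The `jitter=1e-4` argument is not explained above. In Gaussian-process code, a jitter term is usually a small constant added to the diagonal of a kernel matrix (here we assume the inducing-point matrix) so that its Cholesky factorisation stays numerically stable. The plain-Python sketch below illustrates that general idea with hypothetical helper names; it does not reproduce how `src.SFR` uses `jitter`.

```python
import math


def cholesky(matrix):
    """Plain-Python Cholesky factorisation; raises ValueError if not positive definite."""
    n = len(matrix)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                diag = matrix[i][i] - s
                if diag <= 0.0:
                    raise ValueError("matrix is not positive definite")
                lower[i][j] = math.sqrt(diag)
            else:
                lower[i][j] = (matrix[i][j] - s) / lower[j][j]
    return lower


def add_jitter(matrix, jitter):
    """Return a copy of `matrix` with `jitter` added to its diagonal."""
    return [
        [value + (jitter if i == j else 0.0) for j, value in enumerate(row)]
        for i, row in enumerate(matrix)
    ]


# Two identical inducing inputs give a singular (rank-1) kernel matrix ...
K_zz = [[1.0, 1.0], [1.0, 1.0]]
try:
    cholesky(K_zz)
except ValueError as err:
    print("without jitter:", err)

# ... while a small diagonal jitter makes it positive definite again.
L = cholesky(add_jitter(K_zz, 1e-4))
print("with jitter:", L)
```

The jitter trades a tiny bias in the kernel matrix for a factorisation that does not fail when inducing points are (nearly) duplicated or the matrix is ill-conditioned.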

Development

Set up pre-commit by running:

pre-commit install

Now, whenever you commit, the formatter, linter, and other hooks will run automatically.

Citation

Please consider citing our conference paper:

@inproceedings{scannellFunction2024,
  title     = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author    = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year      = {2024},
  month     = {5},
}

Or our workshop paper:

@inproceedings{scannellSparse2023,
  title     = {Sparse Function-space Representation of Neural Networks},
  booktitle = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  author    = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year      = {2023},
  month     = {7},
}
