PyTorch implementation of Sparse Function-space Representation of Neural Networks
AaltoML/sfr
This repository contains a clean and minimal PyTorch implementation of Sparse Function-space Representation (SFR) of Neural Networks. If you'd like to use SFR, we recommend using this repo. Please see sfr-experiments for reproducing the experiments in the ICLR 2024 paper.
To run on CPU, create an environment with:
conda env create -f env_cpu.yaml
Activate the environment with:
source activate sfr
To run on an NVIDIA GPU, create an environment with:
conda env create -f env_nvidia.yaml
Activate the environment with:
source activate sfr
See the notebooks for how to use our code for both regression and classification.
We provide a minimal training script in train.py, which can be used to train a CNN and fit SFR on MNIST/Fashion-MNIST/CIFAR-10. We advise running it on a GPU.
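Before launching train.py, it can help to confirm that the active sfr environment actually exposes a GPU to PyTorch. The snippet below is a minimal sketch that is not part of this repository; pick_device is a hypothetical helper name, and the only PyTorch call it relies on is torch.cuda.is_available(). It does not change which device train.py itself uses.

```python
import importlib.util


def pick_device() -> str:
    """Hypothetical helper: return "cuda" if PyTorch can see a GPU, else "cpu"."""
    # If PyTorch is not installed in the active environment, fall back to CPU.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    # True only for a CUDA-enabled build (e.g. from env_nvidia.yaml) with a visible GPU.
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(f"PyTorch device: {pick_device()}")
```

If this prints cpu inside an environment created from env_nvidia.yaml, the CUDA build or driver is likely not being picked up.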
Here's a short example:
import src
import torch

torch.set_default_dtype(torch.float64)


def func(x, noise=True):
    # Note: the `noise` flag is not used here, so the targets are noise-free.
    return torch.sin(x * 5) / x + torch.cos(x * 10)


# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # setting a batch size here reduces the memory required for computing dual parameters
    jitter=1e-4,
)

sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data)  # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
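The example couples training to a Gaussian prior with precision delta and a Gaussian likelihood with noise sigma_noise, and then returns Gaussian predictive means and variances. The sketch below illustrates both ideas in plain Python under stated assumptions; it is not the repository's code. gaussian_map_objective shows the textbook negative log joint (constants dropped) for such a prior and likelihood, but whether src.SFR's loss scales its terms the same way (for example, by dataset size) is an assumption. predictive_interval turns a mean and variance into a central ~95% interval, optionally adding observation noise as var = f_var + sigma_noise**2, which assumes sigma_noise is a standard deviation.

```python
import math


def gaussian_map_objective(residuals, params, sigma_noise, delta):
    """Textbook negative log joint (up to constants) for a Gaussian likelihood
    and an isotropic Gaussian prior N(0, delta^-1 I) on the parameters."""
    # Gaussian likelihood term: sum_i r_i^2 / (2 * sigma_noise^2).
    nll = sum(r * r for r in residuals) / (2.0 * sigma_noise**2)
    # Gaussian prior term: (delta / 2) * ||theta||^2, i.e. weight decay.
    neg_log_prior = 0.5 * delta * sum(p * p for p in params)
    return nll + neg_log_prior


def predictive_interval(f_mean, f_var, sigma_noise=None, z=1.96):
    """Central interval mean +/- z * std from a Gaussian predictive.

    With sigma_noise=None the function-space variance f_var is used as is;
    otherwise Gaussian observation noise is added: var = f_var + sigma_noise**2.
    """
    var = f_var if sigma_noise is None else f_var + sigma_noise**2
    half_width = z * math.sqrt(var)
    return f_mean - half_width, f_mean + half_width


if __name__ == "__main__":
    print(gaussian_map_objective([1.0, -1.0], [2.0], sigma_noise=1.0, delta=0.5))
    print(predictive_interval(0.0, 1.0))
```

Assuming f_var from sfr.predict_f holds per-point variances, the same arithmetic applies elementwise to the tensors (e.g. f_mean - 1.96 * f_var.sqrt()). Whether y_var from sfr.predict already includes the noise term is not stated here, so check before adding sigma_noise**2 a second time.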
Set up pre-commit by running:
pre-commit install
Now, when you commit, the formatter, linter, etc. will run automatically.
Please consider citing our conference paper:
@inproceedings{scannellFunction2024,
  title = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tami and Joni Pajarinen and Arno Solin},
  year = {2024},
  month = {5},
}
Or our workshop paper:
@inproceedings{scannellSparse2023,
  title = {Sparse Function-space Representation of Neural Networks},
  maintitle = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tami and Joni Pajarinen and Arno Solin},
  year = {2023},
  month = {7},
}