# Sparse Function-space Representation of Neural Networks
This repository contains a clean and minimal PyTorch implementation of Sparse Function-space Representation (SFR) of Neural Networks. If you'd like to use SFR, we recommend using this repo. Please see sfr-experiments for reproducing the experiments in the ICLR 2024 paper.
For CPU, create an environment with:
```sh
conda env create -f env_cpu.yaml
```
Or, for NVIDIA GPUs, create an environment with:
```sh
conda env create -f env_nvidia.yaml
```
Then activate the environment with:
```sh
source activate sfr
```
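To check the environment is working (and, in the NVIDIA case, that PyTorch can see the GPU), a quick sanity check:
```python
import torch

print(torch.__version__)          # PyTorch should import without errors
print(torch.cuda.is_available())  # True in the NVIDIA environment, False on CPU-only
```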
See the notebooks for how to use our code for both regression and classification.
We provide a minimal training script in train.py which can be used to train a CNN and fit SFR on MNIST/Fashion-MNIST/CIFAR-10. It is advised to run this on a GPU.
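The script wires everything up for these image datasets. Relative to the regression example shown next, the main structural change for classification is the likelihood: swap the Gaussian likelihood for a categorical one. A minimal sketch (the class name `src.likelihoods.CategoricalLh` and its argument-free constructor are assumptions; check `src/likelihoods.py` for the actual API):
```python
import src
import torch

# Hypothetical MNIST-style classifier; only the likelihood and output_dim change
network = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.Tanh(),
    torch.nn.Linear(128, 10),  # one logit per class
)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=1e-4),
    likelihood=src.likelihoods.CategoricalLh(),  # assumed class name, not verified
    output_dim=10,
    num_inducing=128,
)
```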
Here's a short end-to-end regression example:
```python
import src
import torch

torch.set_default_dtype(torch.float64)


def func(x, noise=True):
    return torch.sin(x * 5) / x + torch.cos(x * 10)


# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # this reduces the memory required for computing dual parameters
    jitter=1e-4,
)

# Train the neural network
sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data)  # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
```
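The returned variances make it easy to report uncertainty; for example, a 95% credible interval in output space (a usage sketch continuing the example above):
```python
# 95% credible interval: mean +/- 1.96 standard deviations
y_std = y_var.sqrt()
lower = y_mean - 1.96 * y_std
upper = y_mean + 1.96 * y_std
```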
Set up pre-commit by running:
```sh
pre-commit install
```
Now the formatter and linter will run automatically whenever you commit.
Please consider citing our conference paper:
```bibtex
@inproceedings{scannellFunction2024,
  title = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year = {2024},
  month = {5},
}
```
Or our workshop paper:
```bibtex
@inproceedings{scannellSparse2023,
  title = {Sparse Function-space Representation of Neural Networks},
  maintitle = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year = {2023},
  month = {7},
}
```