Artificial Neural Networks package for R
This package allows you to train neural networks for classification and regression tasks, as well as autoencoders for anomaly detection. Several helper and plotting functions are included to improve usability and to help you understand what the model does.
ANN2 contains a vectorized neural net implementation in C++ that facilitates fast training through mini-batch gradient descent.
ANN2 has the following features:
Defining and training a multilayer neural network with ANN2 is done using a single function call to either:
- neuralnetwork() - for a multilayer neural net for classification or regression,
- autoencoder() - for an autoencoding neural network that is trained to reconstruct its inputs.
Below are two examples with minimal code snippets that show how to use these functions.
neuralnetwork()
We’ll train a neural network with dimensions 4 x 5 x 5 x 3 on the Iris data set. The network classifies each observation (sepal length and width and petal length and width measurements of a flower) as one of three species: setosa, versicolor or virginica. The dimensions of the input and output layers are inferred from the data. The hidden layer dimensions are defined by passing a single vector, which specifies the number of nodes in each hidden layer, as the argument hidden.layers.
    library(ANN2)

    # Prepare test and train sets
    random_idx <- sample(1:nrow(iris), size = 145)
    X_train <- iris[random_idx, 1:4]
    y_train <- iris[random_idx, 5]
    X_test <- iris[setdiff(1:nrow(iris), random_idx), 1:4]
    y_test <- iris[setdiff(1:nrow(iris), random_idx), 5]

    # Train neural network on classification task
    NN <- neuralnetwork(X = X_train,
                        y = y_train,
                        hidden.layers = c(5, 5),
                        optim.type = 'adam',
                        n.epochs = 5000)

    # Predict the class for new data points
    predict(NN, X_test)

    # $predictions
    # [1] "setosa"     "setosa"     "setosa"     "versicolor" "versicolor"
    #
    # $probabilities
    #      class_setosa class_versicolor class_virginica
    # [1,] 0.9998184126     0.0001814204    1.670401e-07
    # [2,] 0.9998311154     0.0001687264    1.582390e-07
    # [3,] 0.9998280223     0.0001718171    1.605735e-07
    # [4,] 0.0001074780     0.9997372337    1.552883e-04
    # [5,] 0.0001077757     0.9996626441    2.295802e-04

    # Plot the training and validation loss
    plot(NN)
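The predicted classes can be compared with the held-out labels to get a rough idea of how well the network generalizes. The following is a minimal sketch, not part of the package documentation. The fixed seed, the 120/30 split and the reduced number of epochs are assumptions chosen to keep the illustration quick, and it relies on predict() returning a $predictions element of class labels, as in the output above.

```r
library(ANN2)

set.seed(42)  # assumption: fixed seed so the train/test split is repeatable

# Hold out 30 observations for testing (illustrative split, not the one used above)
train_idx <- sample(seq_len(nrow(iris)), size = 120)
X_train <- iris[train_idx, 1:4]
y_train <- iris[train_idx, 5]
X_test  <- iris[-train_idx, 1:4]
y_test  <- iris[-train_idx, 5]

# Same architecture as above, fewer epochs to keep the example fast
NN <- neuralnetwork(X = X_train,
                    y = y_train,
                    hidden.layers = c(5, 5),
                    optim.type = 'adam',
                    n.epochs = 500)

pred <- predict(NN, X_test)

# Fraction of test observations whose predicted class matches the true class
accuracy <- mean(pred$predictions == as.character(y_test))

# Cross-tabulate predicted against true classes
confusion <- table(predicted = pred$predictions, actual = y_test)

print(accuracy)
print(confusion)
```

With only a handful of test points, as in the 145/5 split above, such an accuracy estimate is very noisy, which is why this sketch holds out more observations.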
You can interact with the resulting ANN object using the methods plot(), print() and predict(). Storing the model to disk and loading it back can be done using write_ANN() and read_ANN(), respectively. Other, more low-level methods of the C++ module can be accessed through the $ operator as members of the object, e.g. NN$Rcpp_ANN$getParams() for getting the parameters (weight matrices and bias vectors) from the trained model.
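As a rough sketch of the save and load round trip described here: the call shapes write_ANN(object, file) and read_ANN(file) are assumed, and the temporary file name and its extension are hypothetical. The structure of the list returned by getParams() is not assumed, so the sketch only inspects it with str().

```r
library(ANN2)

# Train a small network on the full Iris data (few epochs, for illustration only)
NN <- neuralnetwork(X = iris[, 1:4],
                    y = iris[, 5],
                    hidden.layers = c(5, 5),
                    optim.type = 'adam',
                    n.epochs = 100)

# Hypothetical file location; any writable path should do
model_file <- tempfile(fileext = ".ann")

# Store the trained model on disk and load it back
# (assumed argument order: object first, then file)
write_ANN(NN, model_file)
NN_loaded <- read_ANN(model_file)

# The reloaded object can be used for prediction like the original
head(predict(NN_loaded, iris[, 1:4])$predictions)

# Low-level access to the C++ module: inspect the trained parameters
params <- NN$Rcpp_ANN$getParams()
str(params)
```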
autoencoder()
The autoencoder() function trains an autoencoding neural network. In the next example we’ll train an autoencoder of dimension 4 x 10 x 3 x 10 x 4 on the USArrests data set. The middle hidden layer acts as a bottleneck that forces the autoencoder to retain only structural variation and discard random variation. By flagging data points that are poorly reconstructed (high reconstruction error) as aberrant, we exploit this denoising property for anomaly detection.
    library(ANN2)

    # Prepare test and train sets
    random_idx <- sample(1:nrow(USArrests), size = 45)
    X_train <- USArrests[random_idx, ]
    X_test <- USArrests[setdiff(1:nrow(USArrests), random_idx), ]

    # Define and train autoencoder
    AE <- autoencoder(X = X_train,
                      hidden.layers = c(10, 3, 10),
                      loss.type = 'pseudo-huber',
                      optim.type = 'adam',
                      n.epochs = 5000)

    # Reconstruct test data
    reconstruct(AE, X_test)

    # $reconstructed
    #         Murder   Assault UrbanPop      Rape
    # [1,]  8.547431 243.85898 75.60763 37.791746
    # [2,] 12.972505 268.68226 65.40411 29.475545
    # [3,]  2.107441  78.55883 67.75074 15.040075
    # [4,]  2.085750  56.76030 55.32376  9.346483
    # [5,] 12.936357 252.09209 56.24075 24.964715
    #
    # $anomaly_scores
    # [1]  398.926431  247.238111   11.613522    0.134633 1029.806121

    # Plot original points and their reconstructions for the training data
    reconstruction_plot(AE, X_train)
In the reconstruction plot we see the original points (grey) along with their reconstructions (color scale based on reconstruction error), connected to each other by a grey line. The length of the line denotes the reconstruction error.
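To turn the anomaly scores returned by reconstruct() into a yes/no decision, some cutoff is needed. The sketch below reuses the five scores printed in the example output and flags those above their 75th percentile. That cutoff rule is an assumption made purely for illustration. In practice a threshold would more likely be derived from scores on the training data.

```r
# Anomaly scores copied from the reconstruct() output shown above
anomaly_scores <- c(398.926431, 247.238111, 11.613522, 0.134633, 1029.806121)

# Hypothetical cutoff rule: flag scores above the 75th percentile
threshold <- unname(quantile(anomaly_scores, probs = 0.75))
is_anomaly <- anomaly_scores > threshold

# Indices of the observations flagged as aberrant
flagged <- which(is_anomaly)
print(flagged)

# Observations ranked from most to least poorly reconstructed
ranking <- order(anomaly_scores, decreasing = TRUE)
print(ranking)
```

With these five scores, only the fifth observation exceeds the cutoff, and it also ranks first by reconstruction error.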
You can interact with the ANN object that results from the autoencoder() function call using various methods, including plot(), encode(), decode() and reconstruct().
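The sketch below illustrates how encode() and decode() might be used together. The call shapes encode(object, X) and decode(object, compressed) are assumptions, as is the expectation that encode() returns the activations of the bottleneck layer. Check the manual for the exact arguments before relying on it.

```r
library(ANN2)

# Train a small autoencoder on the full USArrests data (few epochs, for illustration only)
AE <- autoencoder(X = USArrests,
                  hidden.layers = c(10, 3, 10),
                  loss.type = 'pseudo-huber',
                  optim.type = 'adam',
                  n.epochs = 500)

# Assumed call shape: map each observation to its compressed representation
compressed <- encode(AE, USArrests)
dim(compressed)

# Assumed call shape: map the compressed representation back to the input space
decoded <- decode(AE, compressed)
head(decoded)
```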
More details on the arguments supported by neuralnetwork() and autoencoder(), as well as examples and explanations of the helper and plotting functions, can be found in the manual.