maciejkula/rustlearn: Machine learning crate for Rust
A machine learning package for Rust.
For full usage details, see the API documentation.
This crate contains reasonably effective implementations of a number of common machine learning algorithms.
At the moment, `rustlearn` uses its own basic dense and sparse array types, but I will be happy to use something more robust once a clear winner in that space emerges.
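To illustrate the difference between the two storage layouts, here is a minimal, self-contained sketch of a row-major dense matrix and a compressed sparse row (CSR) matrix. The `Dense` and `Csr` types are hypothetical and are not rustlearn's own array types, which may use a different internal layout; the point is only that a sparse array stores the non-zero entries and treats everything else as zero.

```rust
/// A toy dense matrix stored in row-major order (hypothetical, for illustration only).
struct Dense {
    rows: usize,
    cols: usize,
    data: Vec<f32>,
}

/// A toy compressed sparse row (CSR) matrix that stores only non-zero entries.
struct Csr {
    rows: usize,
    cols: usize,
    indptr: Vec<usize>,  // row r's entries live at positions indptr[r]..indptr[r + 1]
    indices: Vec<usize>, // column index of each stored value
    values: Vec<f32>,    // the stored non-zero values
}

impl Dense {
    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.cols + col]
    }
}

impl Csr {
    /// Convert a dense matrix to CSR, skipping zero entries.
    fn from_dense(d: &Dense) -> Csr {
        let mut indptr = vec![0];
        let mut indices = Vec::new();
        let mut values = Vec::new();
        for r in 0..d.rows {
            for c in 0..d.cols {
                let v = d.get(r, c);
                if v != 0.0 {
                    indices.push(c);
                    values.push(v);
                }
            }
            // Record where the next row's entries start.
            indptr.push(values.len());
        }
        Csr { rows: d.rows, cols: d.cols, indptr, indices, values }
    }

    /// Look up an entry; entries that are not stored are implicitly zero.
    fn get(&self, row: usize, col: usize) -> f32 {
        for k in self.indptr[row]..self.indptr[row + 1] {
            if self.indices[k] == col {
                return self.values[k];
            }
        }
        0.0
    }
}

fn main() {
    // 2 x 3 matrix: [[1, 0, 0], [0, 2, 3]]
    let dense = Dense { rows: 2, cols: 3, data: vec![1.0, 0.0, 0.0, 0.0, 2.0, 3.0] };
    let sparse = Csr::from_dense(&dense);
    // prints "2x3 matrix: 3 of 6 entries stored"
    println!(
        "{}x{} matrix: {} of {} entries stored",
        sparse.rows, sparse.cols, sparse.values.len(), dense.data.len()
    );
    assert_eq!(sparse.get(1, 2), dense.get(1, 2));
}
```

For data where most entries are zero (for example, bag-of-words features), the CSR form saves both memory and work, at the cost of slower random access to individual entries.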
The following models are supported:
- logistic regression using stochastic gradient descent,
- support vector machines using the libsvm library,
- decision trees using the CART algorithm,
- random forests using CART decision trees, and
- factorization machines.
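As a rough illustration of the first model in the list, the sketch below implements plain stochastic gradient descent for binary logistic regression in standard-library Rust. The function names (`sigmoid`, `predict_proba`, `sgd_epoch`) are hypothetical, and rustlearn's own SGD classifier may differ, for example in how it schedules step sizes or applies the L1 penalty (omitted here). It only shows the per-sample update, which moves the weights against the gradient of the log loss, (p - y) * x, plus an optional L2 term.

```rust
/// Logistic (sigmoid) function.
fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

/// Predicted probability of the positive class for one sample.
fn predict_proba(weights: &[f32], bias: f32, x: &[f32]) -> f32 {
    let z = weights.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + bias;
    sigmoid(z)
}

/// One pass of plain SGD over the data (hypothetical helper): for each sample,
/// step against the log-loss gradient (p - y) * x, plus an L2 penalty on the weights.
fn sgd_epoch(
    weights: &mut [f32],
    bias: &mut f32,
    xs: &[Vec<f32>],
    ys: &[f32],
    learning_rate: f32,
    l2_penalty: f32,
) {
    for (x, &y) in xs.iter().zip(ys) {
        let error = predict_proba(weights, *bias, x) - y;
        for (w, xi) in weights.iter_mut().zip(x) {
            *w -= learning_rate * (error * xi + l2_penalty * *w);
        }
        // The bias is not penalized.
        *bias -= learning_rate * error;
    }
}

fn main() {
    // Tiny 1-D data set: negative inputs are class 0, positive inputs are class 1.
    let xs = vec![vec![-2.0], vec![-1.0], vec![1.0], vec![2.0]];
    let ys = vec![0.0, 0.0, 1.0, 1.0];
    let mut weights = vec![0.0f32];
    let mut bias = 0.0f32;
    for _ in 0..100 {
        sgd_epoch(&mut weights, &mut bias, &xs, &ys, 0.5, 0.0);
    }
    println!("P(y = 1 | x = 2) = {:.3}", predict_proba(&weights, bias, &[2.0]));
}
```

Repeating `sgd_epoch` plays the same role as the epoch loop in the cross-validation example further down.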
All the models support fitting and prediction on both dense and sparse data, and the implementations should be roughly competitive with Python `sklearn` implementations, both in accuracy and performance.
A number of models support both parallel model fitting and prediction.
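Prediction parallelizes naturally because each row is scored independently. The sketch below, assuming a simple linear scorer and Rust 1.63 or later for `std::thread::scope`, splits the rows into contiguous chunks, scores each chunk on its own thread, and joins the results in order. It is a generic illustration rather than a description of how rustlearn parallelizes fitting or prediction; `predict_row` and `parallel_predict` are hypothetical names.

```rust
use std::thread;

/// Score one row with a linear model (w . x). Rows are independent,
/// so they can be scored on different threads without synchronization.
fn predict_row(weights: &[f32], row: &[f32]) -> f32 {
    weights.iter().zip(row).map(|(w, x)| w * x).sum()
}

/// Split `rows` into up to `n_threads` contiguous chunks and score each chunk on its own thread.
fn parallel_predict(weights: &[f32], rows: &[Vec<f32>], n_threads: usize) -> Vec<f32> {
    let n = n_threads.max(1);
    // Ceiling division, never zero (`chunks` panics on a chunk size of 0).
    let chunk_size = ((rows.len() + n - 1) / n).max(1);
    thread::scope(|s| {
        let handles: Vec<_> = rows
            .chunks(chunk_size)
            .map(move |chunk| {
                s.spawn(move || chunk.iter().map(|r| predict_row(weights, r)).collect::<Vec<f32>>())
            })
            .collect();
        // Joining in spawn order keeps the output aligned with the input rows.
        handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
    })
}

fn main() {
    let weights = vec![1.0f32, 2.0];
    let rows: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32, 1.0]).collect();
    let predictions = parallel_predict(&weights, &rows, 4);
    println!("{:?}", predictions);
}
```

Scoped threads let the workers borrow `weights` and `rows` directly instead of cloning them or wrapping them in `Arc`.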
Model serialization is supported via `serde`.
Usage should be straightforward.
- import the prelude for all the linear algebra primitives and common traits:

  ```rust
  use rustlearn::prelude::*;
  ```

- import individual models and utilities from submodules:

  ```rust
  use rustlearn::prelude::*;

  use rustlearn::linear_models::sgdclassifier::Hyperparameters;
  // more imports
  ```
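Logistic regression, evaluated with 10-fold cross-validation on the iris dataset:

```rust
#![allow(non_snake_case)] // keep the conventional X / X_train names

use rustlearn::prelude::*;

use rustlearn::cross_validation::CrossValidation;
use rustlearn::datasets::iris;
use rustlearn::linear_models::sgdclassifier::Hyperparameters;
use rustlearn::metrics::accuracy_score;

fn main() {
    let (X, y) = iris::load_data();

    let num_splits = 10;
    let num_epochs = 5;

    let mut accuracy = 0.0;

    for (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {
        // Select the training and test rows for this fold.
        let X_train = X.get_rows(&train_idx);
        let y_train = y.get_rows(&train_idx);
        let X_test = X.get_rows(&test_idx);
        let y_test = y.get_rows(&test_idx);

        // Iris has three classes, so the binary SGD classifier is wrapped one-vs-rest.
        let mut model = Hyperparameters::new(X.cols())
            .learning_rate(0.5)
            .l2_penalty(0.0)
            .l1_penalty(0.0)
            .one_vs_rest();

        // Call `fit` once per epoch.
        for _ in 0..num_epochs {
            model.fit(&X_train, &y_train).unwrap();
        }

        let prediction = model.predict(&X_test).unwrap();
        accuracy += accuracy_score(&y_test, &prediction);
    }

    // Average the accuracy over all folds.
    accuracy /= num_splits as f32;
    println!("Mean accuracy: {}", accuracy);
}
```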
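The cross-validation loop relies on two pieces: an iterator that yields a `(train_idx, test_idx)` pair per fold, and an accuracy metric. The sketch below is a hypothetical, dependency-free stand-in for both (`k_fold` and `accuracy` are illustrative names, not rustlearn functions). It assumes contiguous, unshuffled folds; rustlearn's `CrossValidation` may shuffle indices, so its folds will generally differ.

```rust
/// Build (train, test) index pairs for k-fold cross-validation over `n` samples.
/// Fold i uses a contiguous block of indices as its test set; the rest is training data.
fn k_fold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(k >= 1 && k <= n, "need 1 <= k <= n");
    (0..k)
        .map(|fold| {
            // Integer division spreads any remainder so fold sizes differ by at most one.
            let start = fold * n / k;
            let end = (fold + 1) * n / k;
            let test: Vec<usize> = (start..end).collect();
            let train: Vec<usize> = (0..start).chain(end..n).collect();
            (train, test)
        })
        .collect()
}

/// Fraction of predictions that exactly match the true labels.
fn accuracy(y_true: &[f32], y_pred: &[f32]) -> f32 {
    assert_eq!(y_true.len(), y_pred.len());
    let correct = y_true.iter().zip(y_pred).filter(|(a, b)| a == b).count();
    correct as f32 / y_true.len() as f32
}

fn main() {
    for (i, (train, test)) in k_fold(10, 3).iter().enumerate() {
        println!("fold {}: {} train, {} test", i, train.len(), test.len());
    }
    println!("accuracy: {}", accuracy(&[1.0, 0.0, 1.0, 1.0], &[1.0, 1.0, 1.0, 0.0]));
}
```

Every sample lands in exactly one test fold, which is what makes the averaged accuracy an estimate over the whole data set.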
Random forest on the iris dataset, with optional serialization:

```rust
use rustlearn::prelude::*;

use rustlearn::datasets::iris;
use rustlearn::ensemble::random_forest::Hyperparameters;
use rustlearn::trees::decision_tree;

fn main() {
    let (data, target) = iris::load_data();

    // Settings shared by every CART tree in the forest.
    let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
    tree_params.min_samples_split(10).max_features(4);

    // A forest of 10 trees, wrapped one-vs-rest for the three iris classes.
    let mut model = Hyperparameters::new(tree_params, 10).one_vs_rest();

    model.fit(&data, &target).unwrap();

    // Optionally serialize and deserialize the model
    // let encoded = bincode::serialize(&model).unwrap();
    // let decoded: OneVsRestWrapper<RandomForest> = bincode::deserialize(&encoded).unwrap();

    let prediction = model.predict(&data).unwrap();
}
```
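Both examples call `.one_vs_rest()`, which wraps a model so it can be applied to multiclass data such as iris. The sketch below shows the general one-vs-rest technique under two stated assumptions: one binary target vector is built per class, and the final prediction is the class whose binary model gives the highest score. The helper names are hypothetical, and rustlearn's wrapper may combine the per-class scores differently.

```rust
/// One-vs-rest: build one binary target vector per class,
/// 1.0 where the label equals that class and 0.0 elsewhere.
fn one_vs_rest_targets(labels: &[f32], classes: &[f32]) -> Vec<Vec<f32>> {
    classes
        .iter()
        .map(|&c| {
            labels
                .iter()
                .map(|&y| if y == c { 1.0 } else { 0.0 })
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Combine per-class scores (one score per binary model, in the same order as
/// `classes`) by picking the class whose model is most confident.
/// Panics if `classes` is empty.
fn predict_one_vs_rest(scores: &[f32], classes: &[f32]) -> f32 {
    let (best, _) = scores
        .iter()
        .enumerate()
        .fold((0, f32::NEG_INFINITY), |(bi, bs), (i, &s)| if s > bs { (i, s) } else { (bi, bs) });
    classes[best]
}

fn main() {
    let classes = [0.0, 1.0, 2.0];
    let targets = one_vs_rest_targets(&[0.0, 1.0, 2.0, 1.0], &classes);
    for (c, t) in classes.iter().zip(&targets) {
        println!("class {}: {:?}", c, t);
    }
    // Suppose the three binary models scored a new sample as 0.1, 0.7 and 0.2.
    println!("predicted class: {}", predict_one_vs_rest(&[0.1, 0.7, 0.2], &classes));
}
```

Each binary model is trained on one of the target vectors, so a k-class problem needs k separate fits.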
Pull requests are welcome. To run basic tests, run `cargo test`.

Running `cargo test --features "all_tests" --release` runs all tests, including generated and slow tests. Running `cargo bench --features bench` (only on the nightly branch) runs benchmarks.