Machine learning crate for Rust


maciejkula/rustlearn


A machine learning package for Rust.

For full usage details, see the API documentation.

Introduction

This crate contains reasonably effective implementations of a number of common machine learning algorithms.

At the moment, rustlearn uses its own basic dense and sparse array types, but I will be happy to use something more robust once a clear winner in that space emerges.
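The README does not describe how these array types are laid out. As a rough illustration of why both kinds exist, here is a std-only sketch of a row-major dense matrix next to a row-wise sparse matrix that stores only non-zero (column, value) pairs. DenseMatrix and SparseRows are hypothetical names, not rustlearn's types; see the API documentation for the real ones.

```rust
/// Hypothetical dense matrix: every value stored, row-major.
/// Illustrative only; this is not rustlearn's dense array type.
struct DenseMatrix {
    rows: usize,
    cols: usize,
    data: Vec<f32>,
}

impl DenseMatrix {
    fn zeros(rows: usize, cols: usize) -> Self {
        DenseMatrix { rows, cols, data: vec![0.0; rows * cols] }
    }

    fn get(&self, row: usize, col: usize) -> f32 {
        self.data[row * self.cols + col]
    }

    fn set(&mut self, row: usize, col: usize, value: f32) {
        self.data[row * self.cols + col] = value;
    }
}

/// Hypothetical sparse matrix: each row keeps only its non-zero entries
/// as (column, value) pairs sorted by column.
struct SparseRows {
    cols: usize,
    rows: Vec<Vec<(usize, f32)>>,
}

impl SparseRows {
    /// Build from a dense matrix, dropping explicit zeros.
    fn from_dense(dense: &DenseMatrix) -> Self {
        let rows: Vec<Vec<(usize, f32)>> = (0..dense.rows)
            .map(|r| {
                (0..dense.cols)
                    .filter_map(|c| {
                        let v = dense.get(r, c);
                        if v != 0.0 { Some((c, v)) } else { None }
                    })
                    .collect::<Vec<(usize, f32)>>()
            })
            .collect();
        SparseRows { cols: dense.cols, rows }
    }

    /// Look up a value; missing entries are implicit zeros.
    fn get(&self, row: usize, col: usize) -> f32 {
        assert!(col < self.cols, "column out of range");
        match self.rows[row].binary_search_by_key(&col, |&(c, _)| c) {
            Ok(pos) => self.rows[row][pos].1,
            Err(_) => 0.0,
        }
    }

    /// Number of explicitly stored (non-zero) entries.
    fn nnz(&self) -> usize {
        self.rows.iter().map(|r| r.len()).sum()
    }
}

fn main() {
    let mut dense = DenseMatrix::zeros(3, 4);
    dense.set(0, 1, 2.5);
    dense.set(2, 3, -1.0);
    let sparse = SparseRows::from_dense(&dense);
    // 12 dense slots, but only 2 stored sparse entries.
    println!("stored: {}", sparse.nnz());
    println!("{} {}", sparse.get(0, 1), sparse.get(1, 1));
}
```

The sparse form pays off when most features are zero (for example, bag-of-words text features), since storage and dot products scale with the number of non-zeros rather than the full width.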

Features

Matrix primitives

Models

All the models support fitting and prediction on both dense and sparse data, and the implementations should be roughly competitive with Python's scikit-learn (sklearn) implementations in both accuracy and performance.
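One way a single model implementation can serve both dense and sparse inputs is to write it against a small row abstraction. The sketch below assumes a hypothetical Row trait with a dot product and uses a logistic prediction as the model; it illustrates the idea only and is not how rustlearn structures its code.

```rust
/// Hypothetical row abstraction: anything that can be dotted with dense weights.
trait Row {
    fn dot(&self, weights: &[f32]) -> f32;
}

/// Dense row: a plain slice with one value per feature.
impl Row for [f32] {
    fn dot(&self, weights: &[f32]) -> f32 {
        self.iter().zip(weights).map(|(x, w)| x * w).sum()
    }
}

/// Sparse row: only (column, value) pairs for non-zero features,
/// so zero features are skipped entirely.
impl Row for [(usize, f32)] {
    fn dot(&self, weights: &[f32]) -> f32 {
        self.iter().map(|&(col, x)| x * weights[col]).sum()
    }
}

/// Logistic prediction written once, generic over the row representation.
fn predict_proba<R: Row + ?Sized>(row: &R, weights: &[f32], bias: f32) -> f32 {
    let z = row.dot(weights) + bias;
    1.0 / (1.0 + (-z).exp())
}

fn main() {
    let weights: [f32; 3] = [0.5, -1.0, 2.0];
    // The same feature vector in both representations.
    let dense: &[f32] = &[1.0, 0.0, 0.5];
    let sparse: &[(usize, f32)] = &[(0, 1.0), (2, 0.5)];
    println!(
        "dense: {} sparse: {}",
        predict_proba(dense, &weights, 0.0),
        predict_proba(sparse, &weights, 0.0)
    );
}
```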

Cross-validation
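The logistic regression example below iterates over CrossValidation::new(X.rows(), num_splits), which yields pairs of training and test row indices. A minimal std-only sketch of k-fold splitting follows; the hypothetical k_fold function uses contiguous, unshuffled folds, which may differ from how rustlearn assigns rows to folds.

```rust
/// Hypothetical k-fold splitter: returns k (train_indices, test_indices) pairs.
/// Folds are contiguous, not shuffled, and differ in size by at most one.
fn k_fold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(k >= 2 && k <= n, "need 2 <= k <= n");
    let base = n / k;
    let extra = n % k; // the first `extra` folds get one more row
    let mut splits = Vec::with_capacity(k);
    let mut start = 0;
    for fold in 0..k {
        let size = base + if fold < extra { 1 } else { 0 };
        let end = start + size;
        // The current fold is held out for testing; everything else trains.
        let test: Vec<usize> = (start..end).collect();
        let train: Vec<usize> = (0..start).chain(end..n).collect();
        splits.push((train, test));
        start = end;
    }
    splits
}

fn main() {
    for (train, test) in k_fold(7, 3) {
        println!("train={:?} test={:?}", train, test);
    }
}
```

Every row appears in exactly one test fold, so averaging a metric over the folds (as the example does with accuracy) uses each row for evaluation once.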

Metrics
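The logistic regression example below scores each fold with accuracy_score. A sketch of that metric, the fraction of predictions that match the true labels, is shown here; accuracy is a hypothetical stand-in, and labels are f32 only to match the f32 accumulator used in the example.

```rust
/// Hypothetical accuracy metric (not rustlearn's accuracy_score): the fraction
/// of positions where the predicted label equals the true label. Exact float
/// comparison is acceptable here because class labels are whole numbers.
fn accuracy(y_true: &[f32], y_pred: &[f32]) -> f32 {
    assert_eq!(y_true.len(), y_pred.len(), "label vectors differ in length");
    assert!(!y_true.is_empty(), "cannot score an empty prediction");
    let correct = y_true.iter().zip(y_pred).filter(|(t, p)| t == p).count();
    correct as f32 / y_true.len() as f32
}

fn main() {
    let y_true: [f32; 4] = [0.0, 1.0, 2.0, 1.0];
    let y_pred: [f32; 4] = [0.0, 1.0, 1.0, 1.0];
    // 3 of 4 labels match.
    println!("accuracy = {}", accuracy(&y_true, &y_pred));
}
```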

Parallelization

A number of models support both parallel model fitting and prediction.
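The README does not say which models run in parallel or what mechanism they use. A common pattern for parallel prediction is to split the rows into chunks and score each chunk on its own thread; the sketch below does this with std::thread::scope (Rust 1.63 or later). parallel_predict and its per-row closure are hypothetical, not rustlearn functions.

```rust
use std::thread;

/// Hypothetical parallel predictor: scores every row of a row-major matrix,
/// splitting the rows into at most `n_threads` contiguous chunks.
/// `predict_row` stands in for any per-row model prediction.
fn parallel_predict<F>(data: &[f32], cols: usize, n_threads: usize, predict_row: F) -> Vec<f32>
where
    F: Fn(&[f32]) -> f32 + Sync,
{
    assert!(cols > 0 && n_threads > 0);
    assert_eq!(data.len() % cols, 0, "data length must be a multiple of cols");
    let n_rows = data.len() / cols;
    let mut out = vec![0.0; n_rows];
    // Rows per thread, rounded up so every row is covered.
    let rows_per_chunk = ((n_rows + n_threads - 1) / n_threads).max(1);
    // Scoped threads may borrow `data` and `out` without Arc or cloning.
    thread::scope(|s| {
        for (in_chunk, out_chunk) in data
            .chunks(rows_per_chunk * cols)
            .zip(out.chunks_mut(rows_per_chunk))
        {
            let predict_row = &predict_row;
            s.spawn(move || {
                // Each thread writes only to its own disjoint output slice.
                for (row, slot) in in_chunk.chunks(cols).zip(out_chunk.iter_mut()) {
                    *slot = predict_row(row);
                }
            });
        }
    });
    out
}

fn main() {
    let weights = [1.0_f32, -2.0];
    let data: Vec<f32> = vec![1.0, 1.0, 3.0, 0.5, 0.0, 4.0, 2.0, 2.0, 5.0, 1.0];
    // A linear score per row, computed on up to 3 threads.
    let scores = parallel_predict(&data, 2, 3, |row| {
        row.iter().zip(&weights).map(|(x, w)| x * w).sum()
    });
    println!("{:?}", scores);
}
```

Because each row's prediction is independent and each thread owns a disjoint slice of the output, no locking is needed; the results come back in the original row order.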

Model serialization

Model serialization is supported via serde.

Using rustlearn

Usage should be straightforward.

  • import the prelude for all the linear algebra primitives and common traits:
use rustlearn::prelude::*;
  • import individual models and utilities from submodules:
use rustlearn::prelude::*;

use rustlearn::linear_models::sgdclassifier::Hyperparameters;
// more imports

Examples

Logistic regression

use rustlearn::prelude::*;

use rustlearn::datasets::iris;
use rustlearn::cross_validation::CrossValidation;
use rustlearn::linear_models::sgdclassifier::Hyperparameters;
use rustlearn::metrics::accuracy_score;

let (X, y) = iris::load_data();

let num_splits = 10;
let num_epochs = 5;

let mut accuracy = 0.0;

// 10-fold cross-validation: each iteration yields training and test row indices.
for (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {

    let X_train = X.get_rows(&train_idx);
    let y_train = y.get_rows(&train_idx);
    let X_test = X.get_rows(&test_idx);
    let y_test = y.get_rows(&test_idx);

    // A fresh one-vs-rest SGD classifier for every fold.
    let mut model = Hyperparameters::new(X.cols())
                                    .learning_rate(0.5)
                                    .l2_penalty(0.0)
                                    .l1_penalty(0.0)
                                    .one_vs_rest();

    // Train for num_epochs passes over the training fold.
    for _ in 0..num_epochs {
        model.fit(&X_train, &y_train).unwrap();
    }

    let prediction = model.predict(&X_test).unwrap();
    accuracy += accuracy_score(&y_test, &prediction);
}

// Mean accuracy across the folds.
accuracy /= num_splits as f32;
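The example above configures an SGD classifier with learning_rate, l2_penalty, and l1_penalty, and calls fit once per epoch, which suggests each call makes one incremental pass over the training fold. The sketch below shows that idea for a single binary logistic regression trained with plain fixed-step SGD; SgdLogistic is hypothetical, and rustlearn's own optimizer may differ in step-size handling and in how it applies the penalties.

```rust
/// Subgradient of |w|: -1, 0, or +1. f32::signum is avoided because it
/// returns 1.0 for 0.0, which would push zero weights negative under L1.
fn sign(v: f32) -> f32 {
    if v > 0.0 {
        1.0
    } else if v < 0.0 {
        -1.0
    } else {
        0.0
    }
}

fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical binary logistic regression trained with plain fixed-step SGD.
/// Field names echo the example's hyperparameters; this is not rustlearn's SGDClassifier.
struct SgdLogistic {
    weights: Vec<f32>,
    bias: f32,
    learning_rate: f32,
    l2_penalty: f32,
    l1_penalty: f32,
}

impl SgdLogistic {
    fn new(dim: usize, learning_rate: f32, l2_penalty: f32, l1_penalty: f32) -> Self {
        SgdLogistic { weights: vec![0.0; dim], bias: 0.0, learning_rate, l2_penalty, l1_penalty }
    }

    /// Linear decision value w·x + b.
    fn decision(&self, x: &[f32]) -> f32 {
        self.weights.iter().zip(x).map(|(w, v)| w * v).sum::<f32>() + self.bias
    }

    /// One epoch: a single pass over the data, one update per sample.
    /// Labels must be 0.0 or 1.0.
    fn fit_epoch(&mut self, xs: &[Vec<f32>], ys: &[f32]) {
        for (x, &y) in xs.iter().zip(ys) {
            // Derivative of the log loss with respect to the decision value.
            let err = sigmoid(self.decision(x)) - y;
            for (w, &v) in self.weights.iter_mut().zip(x) {
                let grad = err * v + self.l2_penalty * *w + self.l1_penalty * sign(*w);
                *w -= self.learning_rate * grad;
            }
            // The bias is left unpenalized.
            self.bias -= self.learning_rate * err;
        }
    }

    fn predict(&self, x: &[f32]) -> f32 {
        if self.decision(x) > 0.0 { 1.0 } else { 0.0 }
    }
}

fn main() {
    let xs: Vec<Vec<f32>> = vec![vec![-2.0], vec![-1.0], vec![1.0], vec![2.0]];
    let ys: Vec<f32> = vec![0.0, 0.0, 1.0, 1.0];
    let mut model = SgdLogistic::new(1, 0.5, 0.0, 0.0);
    // Mirrors the example's loop: repeated passes over the same training data.
    for _ in 0..5 {
        model.fit_epoch(&xs, &ys);
    }
    let preds: Vec<f32> = xs.iter().map(|x| model.predict(x)).collect();
    println!("{:?}", preds);
}
```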

Random forest

use rustlearn::prelude::*;

use rustlearn::ensemble::random_forest::Hyperparameters;
use rustlearn::datasets::iris;
use rustlearn::trees::decision_tree;

let (data, target) = iris::load_data();

// Settings shared by every tree in the forest.
let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
    .max_features(4);

// A forest of 10 trees, wrapped for multiclass prediction.
let mut model = Hyperparameters::new(tree_params, 10)
    .one_vs_rest();

model.fit(&data, &target).unwrap();

// Optionally serialize and deserialize the model
// (requires the bincode crate and imports for OneVsRestWrapper and RandomForest):

// let encoded = bincode::serialize(&model).unwrap();
// let decoded: OneVsRestWrapper<RandomForest> = bincode::deserialize(&encoded).unwrap();

// Predicts on the training data, for illustration only.
let prediction = model.predict(&data).unwrap();
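Both examples call .one_vs_rest() because Iris has three classes. The README does not describe the wrapper's internals; the standard one-vs-rest technique trains one binary model per class (that class against all others) and predicts the class whose model scores highest. The std-only sketch below uses hypothetical BinaryModel, CentroidScorer, and OneVsRest names and is not rustlearn's OneVsRestWrapper.

```rust
/// Hypothetical binary-model interface: fit on 0.0/1.0 labels, then score a point.
trait BinaryModel {
    fn fit(&mut self, xs: &[Vec<f32>], ys: &[f32]);
    fn score(&self, x: &[f32]) -> f32;
}

/// Hypothetical toy binary model: scores a point by the negative squared
/// distance to the mean of the positive examples seen during fit.
#[derive(Default)]
struct CentroidScorer {
    centroid: Vec<f32>,
}

impl BinaryModel for CentroidScorer {
    fn fit(&mut self, xs: &[Vec<f32>], ys: &[f32]) {
        let dim = xs.first().map_or(0, |x| x.len());
        let mut sum = vec![0.0_f32; dim];
        let mut count = 0.0_f32;
        for (x, &y) in xs.iter().zip(ys) {
            if y == 1.0 {
                for (s, &v) in sum.iter_mut().zip(x) {
                    *s += v;
                }
                count += 1.0;
            }
        }
        self.centroid = sum.into_iter().map(|s| s / count.max(1.0)).collect();
    }

    fn score(&self, x: &[f32]) -> f32 {
        -self.centroid.iter().zip(x).map(|(c, v)| (c - v) * (c - v)).sum::<f32>()
    }
}

/// Hypothetical one-vs-rest wrapper: one binary model per class,
/// each trained as "this class vs. the rest".
struct OneVsRest<M> {
    classes: Vec<f32>,
    models: Vec<M>,
}

impl<M: BinaryModel + Default> OneVsRest<M> {
    fn fit(xs: &[Vec<f32>], ys: &[f32]) -> Self {
        // Distinct class labels in ascending order.
        let mut classes = ys.to_vec();
        classes.sort_by(|a, b| a.partial_cmp(b).expect("NaN label"));
        classes.dedup();
        let models: Vec<M> = classes
            .iter()
            .map(|&c| {
                // Relabel: 1.0 for the current class, 0.0 for every other class.
                let binary: Vec<f32> = ys.iter().map(|&y| if y == c { 1.0 } else { 0.0 }).collect();
                let mut model = M::default();
                model.fit(xs, &binary);
                model
            })
            .collect();
        OneVsRest { classes, models }
    }

    /// Return the class whose binary model gives the highest score.
    fn predict(&self, x: &[f32]) -> f32 {
        let best = self
            .models
            .iter()
            .map(|m| m.score(x))
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(&b.1).expect("NaN score"))
            .map(|(i, _)| i)
            .expect("model has no classes");
        self.classes[best]
    }
}

fn main() {
    let xs: Vec<Vec<f32>> = vec![vec![0.0], vec![1.0], vec![5.0], vec![6.0], vec![10.0], vec![11.0]];
    let ys: Vec<f32> = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];
    let ovr: OneVsRest<CentroidScorer> = OneVsRest::fit(&xs, &ys);
    println!("{} {} {}", ovr.predict(&[0.2]), ovr.predict(&[5.0]), ovr.predict(&[12.0]));
}
```

Any binary model with a real-valued score can be plugged in through the trait; scores from the per-class models must be on a comparable scale for the argmax to be meaningful.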

Contributing

Pull requests are welcome.

To run basic tests, run cargo test.

Running cargo test --features "all_tests" --release runs all tests, including generated and slow tests. Running cargo bench --features bench (only on the nightly branch) runs benchmarks.
