@@ -628,6 +628,207 @@ impl Model {
628
628
629
629
estimator
630
630
}
631
+
632
+ Algorithm :: ridge =>{
633
+ train_test_split ! ( dataset, x_train, y_train) ;
634
+ hyperparam_f32 ! ( alpha, hyperparams, 1.0 ) ;
635
+ hyperparam_bool ! ( normalize, hyperparams, false ) ;
636
+
637
+ let solver =match hyperparams. get ( "solver" ) {
638
+ Some ( solver) =>match solver. as_str ( ) . unwrap_or ( "cholesky" ) {
639
+ "svd" =>{
640
+ smartcore:: linear:: ridge_regression:: RidgeRegressionSolverName :: SVD
641
+ }
642
+ _ =>{
643
+ smartcore:: linear:: ridge_regression:: RidgeRegressionSolverName :: Cholesky
644
+ }
645
+ } ,
646
+ None => smartcore:: linear:: ridge_regression:: RidgeRegressionSolverName :: SVD ,
647
+ } ;
648
+
649
+ let estimator: Option < Box < dyn Estimator > > =match project. task {
650
+ Task :: regression =>Some (
651
+ Box :: new (
652
+ smartcore:: linear:: ridge_regression:: RidgeRegression :: fit (
653
+ & x_train,
654
+ & y_train,
655
+ smartcore:: linear:: ridge_regression:: RidgeRegressionParameters :: default ( )
656
+ . with_alpha ( alpha)
657
+ . with_normalize ( normalize)
658
+ . with_solver ( solver)
659
+ ) . unwrap ( )
660
+ )
661
+ ) ,
662
+
663
+ Task :: classification =>panic ! ( "Ridge does not support classification" ) ,
664
+ } ;
665
+
666
+ save_estimator ! ( estimator, self ) ;
667
+
668
+ estimator
669
+ }
670
+
671
+ Algorithm :: kmeans =>{
672
+ todo ! ( ) ;
673
+ }
674
+
675
+ Algorithm :: dbscan =>{
676
+ todo ! ( ) ;
677
+ }
678
+
679
+ Algorithm :: knn =>{
680
+ train_test_split ! ( dataset, x_train, y_train) ;
681
+ let algorithm =match hyperparams
682
+ . get ( "algorithm" )
683
+ . unwrap_or ( & serde_json:: Value :: from ( "linear_search" ) )
684
+ . as_str ( )
685
+ . unwrap_or ( "linear_search" )
686
+ {
687
+ "cover_tree" => smartcore:: algorithm:: neighbour:: KNNAlgorithmName :: CoverTree ,
688
+ _ => smartcore:: algorithm:: neighbour:: KNNAlgorithmName :: LinearSearch ,
689
+ } ;
690
+ let weight =match hyperparams
691
+ . get ( "weight" )
692
+ . unwrap_or ( & serde_json:: Value :: from ( "uniform" ) )
693
+ . as_str ( )
694
+ . unwrap_or ( "uniform" )
695
+ {
696
+ "distance" => smartcore:: neighbors:: KNNWeightFunction :: Distance ,
697
+ _ => smartcore:: neighbors:: KNNWeightFunction :: Uniform ,
698
+ } ;
699
+ hyperparam_usize ! ( k, hyperparams, 3 ) ;
700
+
701
+ let estimator: Option < Box < dyn Estimator > > =match project. task {
702
+ Task :: regression =>Some ( Box :: new (
703
+ smartcore:: neighbors:: knn_regressor:: KNNRegressor :: fit (
704
+ & x_train,
705
+ & y_train,
706
+ smartcore:: neighbors:: knn_regressor:: KNNRegressorParameters :: default ( )
707
+ . with_algorithm ( algorithm)
708
+ . with_weight ( weight)
709
+ . with_k ( k) ,
710
+ )
711
+ . unwrap ( ) ,
712
+ ) ) ,
713
+
714
+ Task :: classification =>Some ( Box :: new (
715
+ smartcore:: neighbors:: knn_classifier:: KNNClassifier :: fit (
716
+ & x_train,
717
+ & y_train,
718
+ smartcore:: neighbors:: knn_classifier:: KNNClassifierParameters :: default (
719
+ )
720
+ . with_algorithm ( algorithm)
721
+ . with_weight ( weight)
722
+ . with_k ( k) ,
723
+ )
724
+ . unwrap ( ) ,
725
+ ) ) ,
726
+ } ;
727
+
728
+ save_estimator ! ( estimator, self ) ;
729
+
730
+ estimator
731
+ }
732
+
733
+ Algorithm :: random_forest =>{
734
+ train_test_split ! ( dataset, x_train, y_train) ;
735
+
736
+ let max_depth =match hyperparams. get ( "max_depth" ) {
737
+ Some ( max_depth) =>match max_depth. as_u64 ( ) {
738
+ Some ( max_depth) =>Some ( max_depthas u16 ) ,
739
+ None =>None ,
740
+ } ,
741
+ None =>None ,
742
+ } ;
743
+
744
+ let m =match hyperparams. get ( "m" ) {
745
+ Some ( m) =>match m. as_u64 ( ) {
746
+ Some ( m) =>Some ( mas usize ) ,
747
+ None =>None ,
748
+ } ,
749
+ None =>None ,
750
+ } ;
751
+
752
+ let split_criterion =match hyperparams
753
+ . get ( "split_criterion" )
754
+ . unwrap_or ( & serde_json:: Value :: from ( "gini" ) )
755
+ . as_str ( )
756
+ . unwrap_or ( "gini" ) {
757
+ "entropy" => smartcore:: tree:: decision_tree_classifier:: SplitCriterion :: Entropy ,
758
+ "classification_error" => smartcore:: tree:: decision_tree_classifier:: SplitCriterion :: ClassificationError ,
759
+ _ => smartcore:: tree:: decision_tree_classifier:: SplitCriterion :: Gini ,
760
+ } ;
761
+
762
+ hyperparam_usize ! ( min_samples_leaf, hyperparams, 1 ) ;
763
+ hyperparam_usize ! ( min_samples_split, hyperparams, 2 ) ;
764
+ hyperparam_usize ! ( n_trees, hyperparams, 10 ) ;
765
+ hyperparam_usize ! ( seed, hyperparams, 0 ) ;
766
+ hyperparam_bool ! ( keep_samples, hyperparams, false ) ;
767
+
768
+ let estimator: Option < Box < dyn Estimator > > =match project. task {
769
+ Task :: regression =>{
770
+ let mut params = smartcore:: ensemble:: random_forest_regressor:: RandomForestRegressorParameters :: default ( )
771
+ . with_min_samples_leaf ( min_samples_leaf)
772
+ . with_min_samples_split ( min_samples_split)
773
+ . with_seed ( seedas u64 )
774
+ . with_n_trees ( n_treesas usize )
775
+ . with_keep_samples ( keep_samples) ;
776
+ match max_depth{
777
+ Some ( max_depth) => params = params. with_max_depth ( max_depth) ,
778
+ None =>( ) ,
779
+ } ;
780
+
781
+ match m{
782
+ Some ( m) => params = params. with_m ( m) ,
783
+ None =>( ) ,
784
+ } ;
785
+
786
+ Some (
787
+ Box :: new (
788
+ smartcore:: ensemble:: random_forest_regressor:: RandomForestRegressor :: fit (
789
+ & x_train,
790
+ & y_train,
791
+ params,
792
+ ) . unwrap ( )
793
+ )
794
+ )
795
+ }
796
+
797
+ Task :: classification =>{
798
+ let mut params = smartcore:: ensemble:: random_forest_classifier:: RandomForestClassifierParameters :: default ( )
799
+ . with_min_samples_leaf ( min_samples_leaf)
800
+ . with_min_samples_split ( min_samples_leaf)
801
+ . with_seed ( seedas u64 )
802
+ . with_n_trees ( n_treesas u16 )
803
+ . with_keep_samples ( keep_samples)
804
+ . with_criterion ( split_criterion) ;
805
+
806
+ match max_depth{
807
+ Some ( max_depth) => params = params. with_max_depth ( max_depth) ,
808
+ None =>( ) ,
809
+ } ;
810
+
811
+ match m{
812
+ Some ( m) => params = params. with_m ( m) ,
813
+ None =>( ) ,
814
+ } ;
815
+
816
+ Some (
817
+ Box :: new (
818
+ smartcore:: ensemble:: random_forest_classifier:: RandomForestClassifier :: fit (
819
+ & x_train,
820
+ & y_train,
821
+ params,
822
+ ) . unwrap ( )
823
+ )
824
+ )
825
+ }
826
+ } ;
827
+
828
+ save_estimator ! ( estimator, self ) ;
829
+
830
+ estimator
831
+ }
631
832
} ;
632
833
}
633
834