@@ -87,6 +87,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
87
87
rmp_serde:: from_read ( & * data) . unwrap ( ) ;
88
88
Box :: new ( estimator)
89
89
}
90
+ Algorithm :: elastic_net =>{
91
+ let estimator: smartcore:: linear:: elastic_net:: ElasticNet < f32 , Array2 < f32 > > =
92
+ rmp_serde:: from_read ( & * data) . unwrap ( ) ;
93
+ Box :: new ( estimator)
94
+ }
90
95
Algorithm :: xgboost =>{
91
96
let bst =Booster :: load_buffer ( & * data) . unwrap ( ) ;
92
97
Box :: new ( BoosterBox :: new ( bst) )
@@ -149,6 +154,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
149
154
Box :: new ( estimator)
150
155
}
151
156
Algorithm :: lasso =>panic ! ( "Lasso does not support classification" ) ,
157
+ Algorithm :: elastic_net =>panic ! ( "Elastic Net does not support classification" ) ,
152
158
Algorithm :: xgboost =>{
153
159
let bst =Booster :: load_buffer ( & * data) . unwrap ( ) ;
154
160
Box :: new ( BoosterBox :: new ( bst) )
@@ -285,132 +291,35 @@ pub trait Estimator: Send + Sync + Debug {
285
291
fn predict ( & self , features : Vec < f32 > ) ->f32 ;
286
292
}
287
293
288
- #[ typetag:: serialize]
289
- impl Estimator for smartcore:: linear:: linear_regression:: LinearRegression < f32 , Array2 < f32 > > {
290
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
291
- test_smartcore ( self , task, data)
292
- }
293
-
294
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
295
- predict_smartcore ( self , features)
296
- }
297
- }
298
-
299
- #[ typetag:: serialize]
300
- impl Estimator for smartcore:: linear:: logistic_regression:: LogisticRegression < f32 , Array2 < f32 > > {
301
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
302
- test_smartcore ( self , task, data)
303
- }
304
-
305
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
306
- predict_smartcore ( self , features)
307
- }
308
- }
309
-
310
- // All the SVM kernels :popcorn:
311
-
312
- #[ typetag:: serialize]
313
- impl Estimator for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: LinearKernel > {
314
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
315
- test_smartcore ( self , task, data)
316
- }
317
-
318
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
319
- predict_smartcore ( self , features)
320
- }
321
- }
322
-
323
- #[ typetag:: serialize]
324
- impl Estimator for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: LinearKernel > {
325
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
326
- test_smartcore ( self , task, data)
327
- }
328
-
329
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
330
- predict_smartcore ( self , features)
331
- }
332
- }
333
-
334
- #[ typetag:: serialize]
335
- impl Estimator for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: SigmoidKernel < f32 > > {
336
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
337
- test_smartcore ( self , task, data)
338
- }
339
-
340
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
341
- predict_smartcore ( self , features)
342
- }
343
- }
344
-
345
- #[ typetag:: serialize]
346
- impl Estimator for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: SigmoidKernel < f32 > > {
347
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
348
- test_smartcore ( self , task, data)
349
- }
350
-
351
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
352
- predict_smartcore ( self , features)
353
- }
354
- }
355
-
356
- #[ typetag:: serialize]
357
- impl Estimator
358
- for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: PolynomialKernel < f32 > >
359
- {
360
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
361
- test_smartcore ( self , task, data)
362
- }
363
-
364
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
365
- predict_smartcore ( self , features)
366
- }
367
- }
368
-
369
- #[ typetag:: serialize]
370
- impl Estimator
371
- for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: PolynomialKernel < f32 > >
372
- {
373
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
374
- test_smartcore ( self , task, data)
375
- }
376
-
377
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
378
- predict_smartcore ( self , features)
379
- }
380
- }
381
-
382
- #[ typetag:: serialize]
383
- impl Estimator for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: RBFKernel < f32 > > {
384
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
385
- test_smartcore ( self , task, data)
386
- }
387
-
388
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
389
- predict_smartcore ( self , features)
390
- }
391
- }
392
-
393
- #[ typetag:: serialize]
394
- impl Estimator for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: RBFKernel < f32 > > {
395
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
396
- test_smartcore ( self , task, data)
397
- }
294
+ /// Implement the Estimator trait (it's always the same)
295
+ /// for all supported algorithms.
296
+ macro_rules! smartcore_estimator_impl{
297
+ ( $estimator: ty) =>{
298
+ #[ typetag:: serialize]
299
+ impl Estimator for $estimator{
300
+ fn test( & self , task: Task , data: & Dataset ) ->HashMap <String , f32 >{
301
+ test_smartcore( self , task, data)
302
+ }
398
303
399
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
400
- predict_smartcore ( self , features)
401
- }
304
+ fn predict( & self , features: Vec <f32 >) ->f32 {
305
+ predict_smartcore( self , features)
306
+ }
307
+ }
308
+ } ;
402
309
}
403
310
404
- #[ typetag:: serialize]
405
- impl Estimator for smartcore:: linear:: lasso:: Lasso < f32 , Array2 < f32 > > {
406
- fn test ( & self , task : Task , data : & Dataset ) ->HashMap < String , f32 > {
407
- test_smartcore ( self , task, data)
408
- }
409
-
410
- fn predict ( & self , features : Vec < f32 > ) ->f32 {
411
- predict_smartcore ( self , features)
412
- }
413
- }
311
+ smartcore_estimator_impl ! ( smartcore:: linear:: linear_regression:: LinearRegression <f32 , Array2 <f32 >>) ;
312
+ smartcore_estimator_impl ! ( smartcore:: linear:: logistic_regression:: LogisticRegression <f32 , Array2 <f32 >>) ;
313
+ smartcore_estimator_impl ! ( smartcore:: svm:: svc:: SVC <f32 , Array2 <f32 >, smartcore:: svm:: LinearKernel >) ;
314
+ smartcore_estimator_impl ! ( smartcore:: svm:: svr:: SVR <f32 , Array2 <f32 >, smartcore:: svm:: LinearKernel >) ;
315
+ smartcore_estimator_impl ! ( smartcore:: svm:: svc:: SVC <f32 , Array2 <f32 >, smartcore:: svm:: SigmoidKernel <f32 >>) ;
316
+ smartcore_estimator_impl ! ( smartcore:: svm:: svr:: SVR <f32 , Array2 <f32 >, smartcore:: svm:: SigmoidKernel <f32 >>) ;
317
+ smartcore_estimator_impl ! ( smartcore:: svm:: svc:: SVC <f32 , Array2 <f32 >, smartcore:: svm:: PolynomialKernel <f32 >>) ;
318
+ smartcore_estimator_impl ! ( smartcore:: svm:: svr:: SVR <f32 , Array2 <f32 >, smartcore:: svm:: PolynomialKernel <f32 >>) ;
319
+ smartcore_estimator_impl ! ( smartcore:: svm:: svc:: SVC <f32 , Array2 <f32 >, smartcore:: svm:: RBFKernel <f32 >>) ;
320
+ smartcore_estimator_impl ! ( smartcore:: svm:: svr:: SVR <f32 , Array2 <f32 >, smartcore:: svm:: RBFKernel <f32 >>) ;
321
+ smartcore_estimator_impl ! ( smartcore:: linear:: lasso:: Lasso <f32 , Array2 <f32 >>) ;
322
+ smartcore_estimator_impl ! ( smartcore:: linear:: elastic_net:: ElasticNet <f32 , Array2 <f32 >>) ;
414
323
415
324
pub struct BoosterBox {
416
325
contents : Box < xgboost:: Booster > ,