@@ -94,7 +94,7 @@ mod pgml_rust {
9494) ,
9595} ;
9696
97- let ( mut x, mut y, mut num_rows) =( vec ! [ ] , vec ! [ ] , 0 ) ;
97+ let ( mut x, mut y, mut num_rows, mut num_features ) =( vec ! [ ] , vec ! [ ] , 0 , 0 ) ;
9898
9999let hyperparams = hyperparams. 0 ;
100100
@@ -131,7 +131,7 @@ mod pgml_rust {
131131. into_iter ( )
132132. map ( |column|format ! ( "CAST({} AS REAL)" , column) )
133133. collect :: < Vec < String > > ( ) ;
134-
134+
135135let query =format ! (
136136"SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()" ,
137137 features. clone( ) . join( ", " ) ,
@@ -151,11 +151,22 @@ mod pgml_rust {
151151 num_rows +=1 ;
152152} ) ;
153153
154+ num_features = features. len ( ) ;
155+
154156Ok ( Some ( ( ) ) )
155157} ) ;
156158
157- let mut dtrain =DMatrix :: from_dense ( & x, num_rows) . unwrap ( ) ;
158- dtrain. set_labels ( & y) . unwrap ( ) ;
159+ // todo parameterize test split instead of 0.5
160+ let test_rows =( num_rowsas f32 * 0.5 ) . round ( ) as usize ;
161+ let train_rows = num_rows - test_rows;
162+ let mut dtrain =DMatrix :: from_dense ( & x[ ..train_rows* num_features] , train_rows) . unwrap ( ) ;
163+ let mut dtest =DMatrix :: from_dense ( & x[ train_rows* num_features..] , test_rows) . unwrap ( ) ;
164+ dtrain. set_labels ( & y[ ..train_rows] ) . unwrap ( ) ;
165+ dtest. set_labels ( & y[ train_rows..] ) . unwrap ( ) ;
166+
167+
168+ // specify datasets to evaluate against during training
169+ let evaluation_sets =& [ ( & dtrain, "train" ) , ( & dtest, "test" ) ] ;
159170
160171// configure objectives, metrics, etc.
161172let learning_params = parameters:: learning:: LearningTaskParametersBuilder :: default ( )
@@ -186,8 +197,6 @@ mod pgml_rust {
186197. build ( )
187198. unwrap ( ) ;
188199
189- // specify datasets to evaluate against during training
190- // let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
191200
192201// overall configuration for training/evaluation
193202let params = parameters:: TrainingParametersBuilder :: default ( )
@@ -197,7 +206,7 @@ mod pgml_rust {
197206None =>2 ,
198207} ) // number of training iterations
199208. booster_params ( booster_params) // model parameters
200- // .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
209+ . evaluation_sets ( Some ( evaluation_sets) ) // optional datasets to evaluate against in each iteration
201210. build ( )
202211. unwrap ( ) ;
203212