@@ -160,13 +160,134 @@ impl Model {
160
160
Some ( value) => value. as_u64 ( ) . unwrap_or ( 2 ) as u32 ,
161
161
None =>2 ,
162
162
} )
163
- . eta ( 0.3 )
163
+ . eta ( match hyperparams. get ( "eta" ) {
164
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.3 ) as f32 ,
165
+ None =>match hyperparams. get ( "learning_rate" ) {
166
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.3 ) as f32 ,
167
+ None =>0.3 ,
168
+ } ,
169
+ } )
170
+ . gamma ( match hyperparams. get ( "gamma" ) {
171
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
172
+ None =>match hyperparams. get ( "min_split_loss" ) {
173
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
174
+ None =>0.0 ,
175
+ } ,
176
+ } )
177
+ . min_child_weight ( match hyperparams. get ( "min_child_weight" ) {
178
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 1.0 ) as f32 ,
179
+ None =>1.0 ,
180
+ } )
181
+ . max_delta_step ( match hyperparams. get ( "max_delta_step" ) {
182
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
183
+ None =>0.0 ,
184
+ } )
185
+ . subsample ( match hyperparams. get ( "subsample" ) {
186
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 1.0 ) as f32 ,
187
+ None =>1.0 ,
188
+ } )
189
+ . lambda ( match hyperparams. get ( "lambda" ) {
190
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 1.0 ) as f32 ,
191
+ None =>1.0 ,
192
+ } )
193
+ . alpha ( match hyperparams. get ( "alpha" ) {
194
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
195
+ None =>0.0 ,
196
+ } )
197
+ . tree_method ( match hyperparams. get ( "tree_method" ) {
198
+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "auto" ) {
199
+ "auto" => parameters:: tree:: TreeMethod :: Auto ,
200
+ "exact" => parameters:: tree:: TreeMethod :: Exact ,
201
+ "approx" => parameters:: tree:: TreeMethod :: Approx ,
202
+ "hist" => parameters:: tree:: TreeMethod :: Hist ,
203
+ _ => parameters:: tree:: TreeMethod :: Auto ,
204
+ } ,
205
+
206
+ None => parameters:: tree:: TreeMethod :: Auto ,
207
+ } )
208
+ . sketch_eps ( match hyperparams. get ( "sketch_eps" ) {
209
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.03 ) as f32 ,
210
+ None =>0.03 ,
211
+ } )
212
+ . max_leaves ( match hyperparams. get ( "max_leaves" ) {
213
+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 0 ) as u32 ,
214
+ None =>0 ,
215
+ } )
216
+ . max_bin ( match hyperparams. get ( "max_bin" ) {
217
+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 256 ) as u32 ,
218
+ None =>256 ,
219
+ } )
220
+ . num_parallel_tree ( match hyperparams. get ( "num_parallel_tree" ) {
221
+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 1 ) as u32 ,
222
+ None =>1 ,
223
+ } )
224
+ . grow_policy ( match hyperparams. get ( "grow_policy" ) {
225
+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "depthwise" ) {
226
+ "depthwise" => parameters:: tree:: GrowPolicy :: Depthwise ,
227
+ "lossguide" => parameters:: tree:: GrowPolicy :: LossGuide ,
228
+ _ => parameters:: tree:: GrowPolicy :: Depthwise ,
229
+ } ,
230
+
231
+ None => parameters:: tree:: GrowPolicy :: Depthwise ,
232
+ } )
233
+ . build ( )
234
+ . unwrap ( ) ;
235
+
236
+ let linear_params = parameters:: linear:: LinearBoosterParametersBuilder :: default ( )
237
+ . alpha ( match hyperparams. get ( "alpha" ) {
238
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
239
+ None =>0.0 ,
240
+ } )
241
+ . lambda ( match hyperparams. get ( "lambda" ) {
242
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
243
+ None =>0.0 ,
244
+ } )
245
+ . build ( )
246
+ . unwrap ( ) ;
247
+
248
+ let dart_params = parameters:: dart:: DartBoosterParametersBuilder :: default ( )
249
+ . rate_drop ( match hyperparams. get ( "rate_drop" ) {
250
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
251
+ None =>0.0 ,
252
+ } )
253
+ . one_drop ( match hyperparams. get ( "one_drop" ) {
254
+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 0 ) !=0 ,
255
+ None =>false ,
256
+ } )
257
+ . skip_drop ( match hyperparams. get ( "skip_drop" ) {
258
+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
259
+ None =>0.0 ,
260
+ } )
261
+ . sample_type ( match hyperparams. get ( "sample_type" ) {
262
+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "uniform" ) {
263
+ "uniform" => parameters:: dart:: SampleType :: Uniform ,
264
+ "weighted" => parameters:: dart:: SampleType :: Weighted ,
265
+ _ => parameters:: dart:: SampleType :: Uniform ,
266
+ } ,
267
+ None => parameters:: dart:: SampleType :: Uniform ,
268
+ } )
269
+ . normalize_type ( match hyperparams. get ( "normalize_type" ) {
270
+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "tree" ) {
271
+ "tree" => parameters:: dart:: NormalizeType :: Tree ,
272
+ "forest" => parameters:: dart:: NormalizeType :: Forest ,
273
+ _ => parameters:: dart:: NormalizeType :: Tree ,
274
+ } ,
275
+ None => parameters:: dart:: NormalizeType :: Tree ,
276
+ } )
164
277
. build ( )
165
278
. unwrap ( ) ;
166
279
167
280
// overall configuration for Booster
168
281
let booster_params = parameters:: BoosterParametersBuilder :: default ( )
169
- . booster_type ( parameters:: BoosterType :: Tree ( tree_params) )
282
+ . booster_type ( match hyperparams. get ( "booster" ) {
283
+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "gbtree" ) {
284
+ "gbtree" => parameters:: BoosterType :: Tree ( tree_params) ,
285
+ "linear" => parameters:: BoosterType :: Linear ( linear_params) ,
286
+ "dart" => parameters:: BoosterType :: Dart ( dart_params) ,
287
+ _ => parameters:: BoosterType :: Tree ( tree_params) ,
288
+ } ,
289
+ None => parameters:: BoosterType :: Tree ( tree_params) ,
290
+ } )
170
291
. learning_params ( learning_params)
171
292
. verbose ( true )
172
293
. build ( )