@@ -160,13 +160,134 @@ impl Model {
160160Some ( value) => value. as_u64 ( ) . unwrap_or ( 2 ) as u32 ,
161161None =>2 ,
162162} )
163- . eta ( 0.3 )
163+ . eta ( match hyperparams. get ( "eta" ) {
164+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.3 ) as f32 ,
165+ None =>match hyperparams. get ( "learning_rate" ) {
166+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.3 ) as f32 ,
167+ None =>0.3 ,
168+ } ,
169+ } )
170+ . gamma ( match hyperparams. get ( "gamma" ) {
171+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
172+ None =>match hyperparams. get ( "min_split_loss" ) {
173+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
174+ None =>0.0 ,
175+ } ,
176+ } )
177+ . min_child_weight ( match hyperparams. get ( "min_child_weight" ) {
178+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 1.0 ) as f32 ,
179+ None =>1.0 ,
180+ } )
181+ . max_delta_step ( match hyperparams. get ( "max_delta_step" ) {
182+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
183+ None =>0.0 ,
184+ } )
185+ . subsample ( match hyperparams. get ( "subsample" ) {
186+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 1.0 ) as f32 ,
187+ None =>1.0 ,
188+ } )
189+ . lambda ( match hyperparams. get ( "lambda" ) {
190+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 1.0 ) as f32 ,
191+ None =>1.0 ,
192+ } )
193+ . alpha ( match hyperparams. get ( "alpha" ) {
194+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
195+ None =>0.0 ,
196+ } )
197+ . tree_method ( match hyperparams. get ( "tree_method" ) {
198+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "auto" ) {
199+ "auto" => parameters:: tree:: TreeMethod :: Auto ,
200+ "exact" => parameters:: tree:: TreeMethod :: Exact ,
201+ "approx" => parameters:: tree:: TreeMethod :: Approx ,
202+ "hist" => parameters:: tree:: TreeMethod :: Hist ,
203+ _ => parameters:: tree:: TreeMethod :: Auto ,
204+ } ,
205+
206+ None => parameters:: tree:: TreeMethod :: Auto ,
207+ } )
208+ . sketch_eps ( match hyperparams. get ( "sketch_eps" ) {
209+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.03 ) as f32 ,
210+ None =>0.03 ,
211+ } )
212+ . max_leaves ( match hyperparams. get ( "max_leaves" ) {
213+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 0 ) as u32 ,
214+ None =>0 ,
215+ } )
216+ . max_bin ( match hyperparams. get ( "max_bin" ) {
217+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 256 ) as u32 ,
218+ None =>256 ,
219+ } )
220+ . num_parallel_tree ( match hyperparams. get ( "num_parallel_tree" ) {
221+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 1 ) as u32 ,
222+ None =>1 ,
223+ } )
224+ . grow_policy ( match hyperparams. get ( "grow_policy" ) {
225+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "depthwise" ) {
226+ "depthwise" => parameters:: tree:: GrowPolicy :: Depthwise ,
227+ "lossguide" => parameters:: tree:: GrowPolicy :: LossGuide ,
228+ _ => parameters:: tree:: GrowPolicy :: Depthwise ,
229+ } ,
230+
231+ None => parameters:: tree:: GrowPolicy :: Depthwise ,
232+ } )
233+ . build ( )
234+ . unwrap ( ) ;
235+
236+ let linear_params = parameters:: linear:: LinearBoosterParametersBuilder :: default ( )
237+ . alpha ( match hyperparams. get ( "alpha" ) {
238+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
239+ None =>0.0 ,
240+ } )
241+ . lambda ( match hyperparams. get ( "lambda" ) {
242+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
243+ None =>0.0 ,
244+ } )
245+ . build ( )
246+ . unwrap ( ) ;
247+
248+ let dart_params = parameters:: dart:: DartBoosterParametersBuilder :: default ( )
249+ . rate_drop ( match hyperparams. get ( "rate_drop" ) {
250+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
251+ None =>0.0 ,
252+ } )
253+ . one_drop ( match hyperparams. get ( "one_drop" ) {
254+ Some ( value) => value. as_u64 ( ) . unwrap_or ( 0 ) !=0 ,
255+ None =>false ,
256+ } )
257+ . skip_drop ( match hyperparams. get ( "skip_drop" ) {
258+ Some ( value) => value. as_f64 ( ) . unwrap_or ( 0.0 ) as f32 ,
259+ None =>0.0 ,
260+ } )
261+ . sample_type ( match hyperparams. get ( "sample_type" ) {
262+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "uniform" ) {
263+ "uniform" => parameters:: dart:: SampleType :: Uniform ,
264+ "weighted" => parameters:: dart:: SampleType :: Weighted ,
265+ _ => parameters:: dart:: SampleType :: Uniform ,
266+ } ,
267+ None => parameters:: dart:: SampleType :: Uniform ,
268+ } )
269+ . normalize_type ( match hyperparams. get ( "normalize_type" ) {
270+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "tree" ) {
271+ "tree" => parameters:: dart:: NormalizeType :: Tree ,
272+ "forest" => parameters:: dart:: NormalizeType :: Forest ,
273+ _ => parameters:: dart:: NormalizeType :: Tree ,
274+ } ,
275+ None => parameters:: dart:: NormalizeType :: Tree ,
276+ } )
164277. build ( )
165278. unwrap ( ) ;
166279
167280// overall configuration for Booster
168281let booster_params = parameters:: BoosterParametersBuilder :: default ( )
169- . booster_type ( parameters:: BoosterType :: Tree ( tree_params) )
282+ . booster_type ( match hyperparams. get ( "booster" ) {
283+ Some ( value) =>match value. as_str ( ) . unwrap_or ( "gbtree" ) {
284+ "gbtree" => parameters:: BoosterType :: Tree ( tree_params) ,
285+ "linear" => parameters:: BoosterType :: Linear ( linear_params) ,
286+ "dart" => parameters:: BoosterType :: Dart ( dart_params) ,
287+ _ => parameters:: BoosterType :: Tree ( tree_params) ,
288+ } ,
289+ None => parameters:: BoosterType :: Tree ( tree_params) ,
290+ } )
170291. learning_params ( learning_params)
171292. verbose ( true )
172293. build ( )