You signed in with another tab or window.Reload to refresh your session.You signed out in another tab or window.Reload to refresh your session.You switched accounts on another tab or window.Reload to refresh your session.Dismiss alert
float or int: scores for regressions or ints for classifications
400
484
"""
485
+
# TODO: add metrics for tracking prediction volume/accuracy by model
401
486
returnself.algorithm.predict(data)
402
487
403
488
@@ -406,6 +491,7 @@ def train(
406
491
objective:str,
407
492
relation_name:str,
408
493
y_column_name:str,
494
+
algorithm_name:str="linear",
409
495
test_size:floatorint=0.1,
410
496
test_sampling:str="random",
411
497
):
@@ -416,15 +502,14 @@ def train(
416
502
objective (str): Defaults to "regression". Valid values are ["regression", "classification"].
417
503
relation_name (str): the table or view that stores the training data
418
504
y_column_name (str): the column in the training data that acts as the label
419
-
algorithm (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "random_forest"].
505
+
algorithm_name (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "svm", "random_forest", "gradient_boosting"].
420
506
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
421
507
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
422
508
"""
423
-
ifobjective=="regression":
424
-
algorithms= ["linear","random_forest"]
425
-
elifobjective=="classification":
426
-
algorithms= ["random_forest"]
427
-
else:
509
+
ifalgorithm_nameisNone:
510
+
algorithm_name="linear"
511
+
512
+
ifobjectivenotin ["regression","classification"]:
428
513
raisePgMLException(
429
514
f"Unknown objective `{objective}`, available options are: regression, classification."