Description
Clarification of how we should handle array output type when a metric outputs several values (i.e. accepts multiclass or multioutput input).
The issue was summarised succinctly in #30439 (comment):
Not sure what should be the output namespace / device in case we output an array, e.g. roc_auc_score with average=None on multiclass problems...
Currently, all regression/classification metrics that support the array API and multiclass or multioutput input return an array in the same namespace and on the same device as the input (checked by reading the code and by running it manually). Summary of these metrics:
Regression metrics
Returns array in same namespace/device:
- explained_variance_score
- r2_score
- mean_absolute_error
- mean_absolute_percentage_error
- mean_pinball_loss
- mean_squared_error
- mean_squared_log_error
- root_mean_squared_error
- root_mean_squared_log_error
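To make "same namespace" concrete, here is a minimal sketch of a per-output metric written against a namespace object. `per_output_mse` and its `xp` parameter are hypothetical names for illustration, not scikit-learn's implementation, and NumPy stands in for whichever array library the input comes from.

```python
import numpy as np


def per_output_mse(y_true, y_pred, xp=np):
    """Hypothetical helper: one mean squared error per output column."""
    y_true = xp.asarray(y_true, dtype=xp.float64)
    y_pred = xp.asarray(y_pred, dtype=xp.float64)
    # Reducing over axis 0 (samples) leaves one value per output. The result
    # is created by xp, so it stays in the namespace of the input.
    return xp.mean((y_true - y_pred) ** 2, axis=0)


y_true = np.asarray([[0.5, 1.0], [-1.0, 1.0], [7.0, -6.0]])
y_pred = np.asarray([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])
scores = per_output_mse(y_true, y_pred)  # array of shape (2,), not a list
```

With another array library passed as `xp`, the same code would be expected to return that library's array type rather than a NumPy array.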
Classification metrics
Returns array in same namespace/device:
Looking at the metrics code, if we wanted to return a list of scalars instead, we'd generally have to do extra processing to convert an array (often the output of an xp.* function) into a list of scalars.
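A rough illustration of that extra processing, assuming a 1-D array of per-class or per-output scores. The `scores` values are made up, and NumPy is used only as a stand-in for the input namespace.

```python
import numpy as np

# Stand-in for an array-valued metric output (values are illustrative only).
scores = np.asarray([0.25, 0.5, 1.0])

# Extra step needed for a list of scalars: one conversion per element.
# On a non-CPU device, each float() call would typically also imply a
# transfer of that element to the host.
scores_as_list = [float(scores[i]) for i in range(scores.shape[0])]
```

Returning the array directly avoids this per-element conversion, which is part of why the current behaviour falls out naturally from the existing code.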
Once we arrive at a consensus, we should update the array API documentation and update check_array_api_metric in the tests so that, when the output is an array or a list, we check that the output type etc. is correct.
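As a sketch of what such a check could look like if the consensus is "array in the same namespace and device as the input": `check_array_output` below is a hypothetical helper, not the existing `check_array_api_metric`, and it assumes that comparing the array type and the `device` attribute (when present) is an acceptable proxy for "same namespace/device".

```python
import numpy as np


def check_array_output(metric_output, reference_input):
    """Hypothetical check for array-valued metric outputs."""
    # Same array type as the input (proxy for "same namespace").
    assert type(metric_output) is type(reference_input)
    # Same device, when the array library exposes a `device` attribute.
    assert getattr(metric_output, "device", None) == getattr(
        reference_input, "device", None
    )
    # One value per class/output is expected, i.e. a 1-D array.
    assert metric_output.ndim == 1


y = np.asarray([[1.0, 2.0], [3.0, 4.0]])
out = np.mean(y, axis=0)  # stand-in for a multioutput metric result
check_array_output(out, y)
```

If the consensus goes the other way (a list of scalars), the first assertion would instead check for a `list` of Python floats.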