Movatterモバイル変換


[0]ホーム

URL:


Contents · Menu · Expand · Light mode · Dark mode · Auto light/dark, in light mode · Auto light/dark, in dark mode · Skip to content
IBM watsonx.ai
IBM watsonx.ai
Back to top

Source code for ibm_watsonx_ai.training

#  -----------------------------------------------------------------------------------------#  (C) Copyright IBM Corp. 2023-2025.#  https://opensource.org/licenses/BSD-3-Clause#  -----------------------------------------------------------------------------------------from__future__importannotationsfromtypingimportAny,TYPE_CHECKING,TypeAlias,Iterator,Literal,Callableimportjsonimportloggingimporttimefromwarningsimportwarnfromibm_boto3.exceptionsimportBoto3ErrorfromlomondimportWebSocketimportibm_watsonx_ai._wrappers.requestsasrequestsfromibm_watsonx_ai.messages.messagesimportMessagesfromibm_watsonx_ai.metanamesimportTrainingConfigurationMetaNamesfromibm_watsonx_ai.utilsimport(print_text_header_h1,print_text_header_h2,TRAINING_RUN_DETAILS_TYPE,StatusLogger,)fromibm_watsonx_ai.utils.utilsimport_get_id_from_deprecated_uid,_handle_fl_removalfromibm_watsonx_ai.wml_client_errorimportWMLClientError,ApiRequestFailurefromibm_watsonx_ai.wml_resourceimportWMLResourcelogging.getLogger("lomond").setLevel(logging.CRITICAL)ListType:TypeAlias=listifTYPE_CHECKING:fromibm_watsonx_aiimportAPIClientfrompandasimportDataFrame
class Training(WMLResource):
    """Train new models."""

    def __init__(self, client: APIClient) -> None:
        WMLResource.__init__(self, __name__, client)
        self.ConfigurationMetaNames = TrainingConfigurationMetaNames()

    def get_status(self, training_id: str | None = None, **kwargs: Any) -> dict:
        """Get the status of a created training.

        :param training_id: ID of the training
        :type training_id: str

        :return: the ``entity.status`` section of the training details
        :rtype: dict

        :raises WMLClientError: if the training details could not be fetched

        **Example:**

        .. code-block:: python

            training_status = client.training.get_status(training_id)
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=False
        )
        _is_fine_tuning = kwargs.get("_is_fine_tuning", False)

        Training._validate_type(training_id, "training_id", str, True)

        details = self.get_details(
            training_id, _internal=True, _is_fine_tuning=_is_fine_tuning
        )

        if details is None:
            raise WMLClientError(
                "Getting trained model status failed. Unable to get model details for training_id: '{}'.".format(
                    training_id
                )
            )

        return WMLResource._get_required_element_from_dict(
            details, "details", ["entity", "status"]
        )

    def get_details(
        self,
        training_id: str | None = None,
        limit: int | None = None,
        asynchronous: Literal[True, False] = False,
        get_all: Literal[True, False] = False,
        training_type: str | None = None,
        state: str | None = None,
        tag_value: str | None = None,
        training_definition_id: str | None = None,
        _internal: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Get metadata of training(s). If `training_id` is not specified,
        the metadata of all trainings is returned.

        :param training_id: unique ID of the training
        :type training_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if `True`, it will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if `True`, it will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :param training_type: filter the fetched list of trainings based on the training type ["pipeline" or "experiment"]
        :type training_type: str, optional

        :param state: filter the fetched list of training based on their state: [`queued`, `running`, `completed`, `failed`]
        :type state: str, optional

        :param tag_value: filter the fetched list of training based on their tag value
        :type tag_value: str, optional

        :param training_definition_id: filter the fetched trainings that are using the given training definition
        :type training_definition_id: str, optional

        :return: metadata of training(s)
        :rtype:
            - **dict** - if training_id is not None
            - **{"resources": [dict]}** - if training_id is None

        **Examples**

        .. code-block:: python

            training_run_details = client.training.get_details(training_id)
            training_runs_details = client.training.get_details()
            training_runs_details = client.training.get_details(limit=100)
            training_runs_details = client.training.get_details(limit=100, get_all=True)
            training_runs_details = []
            for entry in client.training.get_details(limit=100, asynchronous=True, get_all=True):
                training_runs_details.extend(entry)
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=True
        )
        _is_fine_tuning = kwargs.get("_is_fine_tuning", False)

        # For CP4D, check if either space or project ID is set
        self._client._check_if_either_is_set()

        Training._validate_type(training_id, "training_id", str, False)

        if _is_fine_tuning:
            url = self._client._href_definitions.get_fine_tunings_href()
        else:
            url = self._client._href_definitions.get_trainings_href()

        if training_id is not None:
            return self._get_artifact_details(url, training_id, limit, "trained models")

        # Only filters that were actually supplied become query parameters.
        query_params: dict | None = {
            param_name: param_value
            for param_name, param_value in (
                ("type", training_type),
                ("state", state),
                ("tag.value", tag_value),
                ("training_definition_id", training_definition_id),
            )
            if param_value is not None
        }
        # An empty filter dict is sent as None (no query parameters at all).
        query_params = query_params or None

        return self._get_artifact_details(
            base_url=url,
            id=training_id,
            limit=limit,
            resource_name="trained models",
            _async=asynchronous,
            _all=get_all,
            query_params=query_params,
        )

    @staticmethod
    def get_href(training_details: dict) -> str:
        """Get the training href from the training details.

        :param training_details: metadata of the created training
        :type training_details: dict

        :return: training href
        :rtype: str

        **Example:**

        .. code-block:: python

            training_details = client.training.get_details(training_id)
            run_url = client.training.get_href(training_details)
        """
        Training._validate_type(training_details, "training_details", object, True)

        if "id" in training_details.get("metadata", {}):
            # Build the href from the ID when the details carry one.
            training_id = WMLResource._get_required_element_from_dict(
                training_details, "training_details", ["metadata", "id"]
            )
            return "/ml/v4/trainings/" + training_id

        Training._validate_type_of_details(training_details, TRAINING_RUN_DETAILS_TYPE)
        return WMLResource._get_required_element_from_dict(
            training_details, "training_details", ["metadata", "href"]
        )

    @staticmethod
    def get_id(training_details: dict) -> str:
        """Get the training ID from the training details.

        :param training_details: metadata of the created training
        :type training_details: dict

        :return: unique ID of the training
        :rtype: str

        **Example:**

        .. code-block:: python

            training_details = client.training.get_details(training_id)
            training_id = client.training.get_id(training_details)
        """
        Training._validate_type(training_details, "training_details", object, True)
        return WMLResource._get_required_element_from_dict(
            training_details, "training_details", ["metadata", "id"]
        )

    def run(self, meta_props: dict, asynchronous: bool = True, **kwargs: Any) -> dict:
        """Create a new Machine Learning training.

        :param meta_props: metadata of the training configuration. To see available meta names, use:

            .. code-block:: python

                client.training.ConfigurationMetaNames.show()

        :type meta_props: dict

        :param asynchronous:
            * `True` - training job is submitted and progress can be checked later
            * `False` - method will wait till job completion and print training stats
        :type asynchronous: bool, optional

        :return: metadata of the training created
        :rtype: dict

        :raises WMLClientError: if neither a default space ID nor a default project ID is set on the client

        .. note::

            You can provide one of the following values for training:
                * client.training.ConfigurationMetaNames.EXPERIMENT
                * client.training.ConfigurationMetaNames.PIPELINE
                * client.training.ConfigurationMetaNames.MODEL_DEFINITION

        **Examples**

        Example of meta_props for creating a training run in IBM Cloud Pak® for Data version 3.0.1 or above:

        .. code-block:: python

            metadata = {
                client.training.ConfigurationMetaNames.NAME: 'Hand-written Digit Recognition',
                client.training.ConfigurationMetaNames.DESCRIPTION: 'Hand-written Digit Recognition Training',
                client.training.ConfigurationMetaNames.PIPELINE: {
                    "id": "4cedab6d-e8e4-4214-b81a-2ddb122db2ab",
                    "rev": "12",
                    "model_type": "string",
                    "data_bindings": [
                        {
                            "data_reference_name": "string",
                            "node_id": "string"
                        }
                    ],
                    "nodes_parameters": [
                        {
                            "node_id": "string",
                            "parameters": {}
                        }
                    ],
                    "hardware_spec": {
                        "id": "4cedab6d-e8e4-4214-b81a-2ddb122db2ab",
                        "rev": "12",
                        "name": "string",
                        "num_nodes": "2"
                    }
                },
                client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [{
                    'type': 's3',
                    'connection': {},
                    'location': {'href': 'v2/assets/asset1233456'},
                    'schema': {
                        'id': 't1',
                        'name': 'Tasks',
                        'fields': [{'name': 'duration', 'type': 'number'}]
                    }
                }],
                client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
                    'id': 'string',
                    'connection': {
                        'endpoint_url': 'https://s3-api.us-geo.objectstorage.service.networklayer.com',
                        'access_key_id': '***',
                        'secret_access_key': '***'
                    },
                    'location': {
                        'bucket': 'wml-dev-results',
                        'path': "path"
                    },
                    'type': 's3'
                }
            }

        Example of a Federated Learning training job:

        .. code-block:: python

            aggregator_metadata = {
                client.training.ConfigurationMetaNames.NAME: 'Federated_Learning_Tensorflow_MNIST',
                client.training.ConfigurationMetaNames.DESCRIPTION: 'MNIST digit recognition with Federated Learning using Tensorflow',
                client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
                client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
                    'type': results_type,
                    'name': 'outputData',
                    'connection': {},
                    'location': {'path': '/projects/' + PROJECT_ID + '/assets/trainings/'}
                },
                client.training.ConfigurationMetaNames.FEDERATED_LEARNING: {
                    'model': {
                        'type': 'tensorflow',
                        'spec': {'id': untrained_model_id},
                        'model_file': untrained_model_name
                    },
                    'fusion_type': 'iter_avg',
                    'metrics': 'accuracy',
                    'epochs': 3,
                    'rounds': 10,
                    'remote_training': {
                        'quorum': 1.0,
                        'max_timeout': 3600,
                        'remote_training_systems': [{'id': prime_rts_id}, {'id': nonprime_rts_id}]
                    },
                    'hardware_spec': {'name': 'S'},
                    'software_spec': {'name': 'runtime-22.1-py3.9'}
                }
            }

            aggregator = client.training.run(aggregator_metadata, asynchronous=True)
            aggregator_id = client.training.get_id(aggregator)
        """
        # For CP4D, check if either space or project ID is set
        self._client._check_if_either_is_set()

        Training._validate_type(meta_props, "meta_props", object, True)
        Training._validate_type(asynchronous, "asynchronous", bool, True)
        _is_fine_tuning = kwargs.get("_is_fine_tuning", False)

        meta_names = self.ConfigurationMetaNames
        meta_names._validate(meta_props)

        training_configuration_metadata: dict = {
            "training_data_references": meta_props[meta_names.TRAINING_DATA_REFERENCES],
            "results_reference": meta_props[meta_names.TRAINING_RESULTS_REFERENCE],
        }

        # Optional sections copied verbatim when present (meta name -> payload key).
        for meta_name, payload_key in (
            (meta_names.TEST_DATA_REFERENCES, "test_data_references"),
            (meta_names.TEST_OUTPUT_DATA, "test_output_data"),
            (meta_names.TAGS, "tags"),
            (meta_names.PROMPT_TUNING, "prompt_tuning"),
            (meta_names.FINE_TUNING, "parameters"),
            (meta_names.AUTO_UPDATE_MODEL, "auto_update_model"),
        ):
            if meta_name in meta_props:
                training_configuration_metadata[payload_key] = meta_props[meta_name]

        # TODO remove when training service starts copying such data on their own
        training_configuration_metadata["name"] = meta_props[meta_names.NAME]
        training_configuration_metadata["description"] = meta_props[meta_names.DESCRIPTION]

        for meta_name, payload_key in (
            (meta_names.PIPELINE, "pipeline"),
            (meta_names.EXPERIMENT, "experiment"),
            (meta_names.MODEL_DEFINITION, "model_definition"),
            (meta_names.SPACE_UID, "space_id"),
        ):
            if meta_name in meta_props:
                training_configuration_metadata[payload_key] = meta_props[meta_name]

        if "type" in meta_props:
            training_configuration_metadata["type"] = meta_props["type"]

        if (
            self._client.default_space_id is None
            and self._client.default_project_id is None
        ):
            raise WMLClientError(
                Messages.get_message(
                    message_id="it_is_mandatory_to_set_the_space_project_id"
                )
            )

        # The client's default scope is written into the payload; a default space
        # wins over a default project.
        if self._client.default_space_id is not None:
            training_configuration_metadata["space_id"] = self._client.default_space_id
        else:
            training_configuration_metadata["project_id"] = self._client.default_project_id

        if meta_names.FEDERATED_LEARNING in meta_props:
            _handle_fl_removal(self._client)
            training_configuration_metadata["federated_learning"] = meta_props[
                meta_names.FEDERATED_LEARNING
            ]

        if _is_fine_tuning:
            train_endpoint = self._client._href_definitions.get_fine_tunings_href()
        else:
            train_endpoint = self._client._href_definitions.get_trainings_href()

        # space_id / project_id were placed in the request body above, so they are
        # removed from the query parameters.
        params = self._client._params()
        params.pop("space_id", None)
        params.pop("project_id", None)
        if self._client.ICP_PLATFORM_SPACES:
            params.pop("userfs", None)

        response_train_post = requests.post(
            train_endpoint,
            json=training_configuration_metadata,
            params=params,
            headers=self._client._get_headers(),
        )

        run_details = self._handle_response(201, "training", response_train_post)

        trained_model_id = self.get_id(run_details)

        if asynchronous is True:
            return run_details

        # Synchronous mode: poll every 5 seconds until a terminal state is reached.
        print_text_header_h1("Running '{}'".format(trained_model_id))

        status = self.get_status(trained_model_id, _is_fine_tuning=_is_fine_tuning)
        state = status["state"]

        with StatusLogger(state) as status_logger:
            while state not in ["error", "completed", "canceled", "failed"]:
                time.sleep(5)
                status = self.get_status(
                    trained_model_id, _is_fine_tuning=_is_fine_tuning
                )
                state = status["state"]
                status_logger.log_state(state)

        if "completed" in state:
            print(
                "\nTraining of '{}' finished successfully.".format(str(trained_model_id))
            )
        else:
            print(
                "\nTraining of '{}' failed with status: '{}'.".format(
                    trained_model_id, str(status)
                )
            )

        self._logger.debug("Response({}):{}".format(state, run_details))
        return self.get_details(
            trained_model_id, _internal=True, _is_fine_tuning=_is_fine_tuning
        )

    def list(
        self,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> DataFrame | Iterator | ListType:
        """List stored trainings in a table format.

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if `True`, it will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if `True`, it will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: pandas.DataFrame with listed trainings (ID, state, creation time),
            or a generator of such tables when `asynchronous` is `True`
        :rtype: pandas.DataFrame

        **Examples**

        .. code-block:: python

            client.training.list()
            training_runs_df = client.training.list(limit=100)
            training_runs_df = client.training.list(limit=100, get_all=True)
            training_runs_df = []
            for entry in client.training.list(limit=100, asynchronous=True, get_all=True):
                training_runs_df.extend(entry)
        """
        # For CP4D, check if either space or project ID is set
        self._client._check_if_either_is_set()

        def preprocess_details(details: dict) -> DataFrame | ListType:
            # One (id, state, created_at) row per training; "guid" is a fallback ID key.
            values = [
                (
                    m["metadata"].get("id", m["metadata"].get("guid")),
                    m["entity"]["status"]["state"],
                    m["metadata"]["created_at"],
                )
                for m in details["resources"]
            ]

            return self._list(
                values,
                ["ID (training)", "STATE", "CREATED"],
                limit=None,
                sort_by=None,
            )

        if asynchronous:
            return (
                preprocess_details(details)
                for details in self.get_details(
                    limit=limit,
                    asynchronous=asynchronous,
                    get_all=get_all,
                    _internal=True,
                )
            )

        details = self.get_details(limit=limit, get_all=get_all, _internal=True)
        return preprocess_details(details)

    def list_intermediate_models(
        self, training_id: str | None = None, **kwargs: Any
    ) -> None:
        """Print the intermediate_models in a table format.

        Only trainings in the `completed` state are listed; for any other state
        nothing is printed.

        :param training_id: ID of the training
        :type training_id: str

        :raises WMLClientError: when called on IBM Cloud Pak® for Data

        .. note::

            This method is not supported for IBM Cloud Pak® for Data.

        **Example:**

        .. code-block:: python

            client.training.list_intermediate_models()
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=False
        )

        if self._client.ICP_PLATFORM_SPACES:
            raise WMLClientError(
                "This method is not supported for IBM Cloud Pak® for Data. "
            )

        # For CP4D, check if either space or project ID is set
        self._client._check_if_either_is_set()

        details = self.get_details(training_id, _internal=True)

        training_state = details["entity"]["status"]["state"]
        if training_state != "completed":
            self._logger.debug("state is not completed")
            return

        metrics_list = details["entity"]["status"].get("metrics")
        if metrics_list is None:
            print(
                " There is no intermediate model metrics are available for this training id. "
            )
            return

        # Metric keys rounded to 2 decimals, in table-column order (see header below).
        rounded_metric_keys = (
            "training_f1_micro",
            "training_f1_macro",
            "training_f1_weighted",
            "training_neg_log_loss",
            "training_precision_micro",
            "training_precision_macro",
            "training_precision_weighted",
            "training_recall_micro",
            "training_recall_macro",
            "training_recall_weighted",
        )
        created_at = details["metadata"]["created_at"]

        rows = []
        for ml in metrics_list:
            if "context" in ml and "intermediate_model" in ml["context"]:
                intermediate_model = ml["context"]["intermediate_model"]
                name = intermediate_model.get("name", "")
                if "location" in intermediate_model:
                    path = intermediate_model["location"].get("model", "")
                else:
                    path = ""
            else:
                name = ""
                path = ""

            # Missing metrics fall back to "" (accuracy) or 0 (rounded metrics).
            ml_metrics = ml.get("ml_metrics") or {}
            row = [name, path, ml_metrics.get("training_accuracy", "")]
            row.extend(round(ml_metrics.get(key, 0), 2) for key in rounded_metric_keys)
            row.append(created_at)

            rows.append(row)
            rows.append([])  # blank separator row between models

        from tabulate import tabulate

        header = [
            "NAME",
            "PATH",
            "Accuracy",
            "F1Micro",
            "F1Macro",
            "F1Weighted",
            "LogLoss",
            "PrecisionMicro",
            "PrecisionMacro",
            "PrecisionWeighted",
            "RecallMicro",
            "RecallMacro",
            "RecallWeighted",
            "CreatedAt",
        ]
        table = tabulate([header] + rows)
        print(table)

    def cancel(
        self,
        training_id: str | None = None,
        hard_delete: bool = False,
        **kwargs: Any,
    ) -> Literal["SUCCESS"]:
        """Cancel a training that is currently running. This method can delete metadata
        details of a completed or canceled training run when `hard_delete` parameter is set to `True`.

        :param training_id: ID of the training
        :type training_id: str

        :param hard_delete: specify `True` or `False`:

            * `True` - to delete the completed or canceled training run
            * `False` - to cancel the currently running training run
        :type hard_delete: bool, optional

        :return: status "SUCCESS" if cancellation is successful
        :rtype: Literal["SUCCESS"]

        **Example:**

        .. code-block:: python

            client.training.cancel(training_id)
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=False
        )
        _is_fine_tuning = kwargs.get("_is_fine_tuning", False)

        # For CP4D, check if either space or project ID is set
        self._client._check_if_either_is_set()

        Training._validate_type(training_id, "training_id", str, True)

        params = self._client._params()
        if hard_delete is True:
            params.update({"hard_delete": "true"})

        if _is_fine_tuning:
            train_endpoint = self._client._href_definitions.get_fine_tuning_href(
                training_id
            )
        else:
            train_endpoint = self._client._href_definitions.get_training_href(
                training_id
            )

        response_delete = requests.delete(
            train_endpoint,
            headers=self._client._get_headers(),
            params=params,
        )

        # A 400 for an already finished job is reported to the user, not raised.
        if (
            response_delete.status_code == 400
            and response_delete.text is not None
            and "Job already completed with state" in response_delete.text
        ):
            print(
                "Job is not running currently. Please use 'hard_delete=True' parameter to force delete"
                " completed or canceled training runs."
            )
            return "SUCCESS"

        return self._handle_response(
            204, "trained model deletion", response_delete, False
        )

    @staticmethod
    def _create_cos_client(results_reference: dict) -> tuple[Any, str]:
        """Build an S3 (IBM COS) client from a training's ``results_reference`` and resolve its bucket.

        The bucket is taken from ``location.bucket``; an empty value falls back to
        ``target.bucket``.

        :return: ``(cos_client, bucket_name)``
        """
        connection = results_reference["connection"]
        bucket = results_reference["location"]["bucket"]
        if bucket == "":
            bucket = results_reference["target"]["bucket"]

        import ibm_boto3

        client_cos = ibm_boto3.client(
            service_name="s3",
            aws_access_key_id=connection["access_key_id"],
            aws_secret_access_key=connection["secret_access_key"],
            endpoint_url=connection["endpoint_url"],
        )
        return client_cos, bucket

    def _COS_logs(self, run_id: str, on_start: Callable = lambda: {}) -> None:
        """Print ``training-log.txt`` of a finished training run from its COS results bucket.

        Problems (no connection, missing objects, COS client errors) are reported
        with ``print`` rather than raised.
        """
        on_start()
        run_details = self.get_details(run_id, _internal=True)
        results_reference = run_details["entity"]["results_reference"]

        if results_reference.get("connection") is None:
            print(" There is no connection details for this Training run, hence no logs.")
            return

        import ibm_boto3

        client_error = ibm_boto3.exceptions.ibm_botocore.client.ClientError
        client_cos, bucket = self._create_cos_client(results_reference)

        try:
            if self._client.CLOUD_PLATFORM_SPACES:
                logs = results_reference.get("location").get("logs")
                if logs is None:
                    print(
                        " There is no logs details for this Training run, hence no logs."
                    )
                    return
                key = logs + "/learner-1/training-log.txt"
            else:
                # Outside the cloud, the log path is derived from the model ID found
                # in the run's pipeline-model.json.
                try:
                    key = (
                        "data/"
                        + run_details["metadata"].get(
                            "id", run_details["metadata"].get("guid")
                        )
                        + "/pipeline-model.json"
                    )
                    obj = client_cos.get_object(Bucket=bucket, Key=key)
                    pipeline_model = json.loads(obj["Body"].read().decode("utf-8"))
                except client_error as ex:
                    if ex.response["Error"]["Code"] == "NoSuchKey":
                        print(
                            " Error - There is no training logs are found for the given training run id"
                        )
                    else:
                        print(ex)
                    return

                if pipeline_model is None:
                    print(" Error - Cannot find the any logs for the given training run id")
                    return

                key = (
                    pipeline_model["pipelines"][0]["nodes"][0]["parameters"]["model_id"]
                    + "/learner-1/training-log.txt"
                )

            obj = client_cos.get_object(Bucket=bucket, Key=key)
            print(obj["Body"].read().decode("utf-8"))

        except client_error as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                print("ERROR - Cannot find training-log.txt in the bucket")
            else:
                print(ex)
                print("ERROR - Cannot get the training run log in the bucket")

    def _COS_metrics(self, run_id: str, on_start: Callable = lambda: {}) -> None:
        """Print ``evaluation-metrics.txt`` of a finished training run from its COS results bucket.

        Problems (no connection, missing objects, COS client errors) are reported
        with ``print`` rather than raised.
        """
        on_start()
        run_details = self.get_details(run_id, _internal=True)
        results_reference = run_details["entity"]["results_reference"]

        if results_reference.get("connection") is None:
            print(
                " There is no connection details for this Training run, hence no metrics."
            )
            return

        import ibm_boto3

        client_error = ibm_boto3.exceptions.ibm_botocore.client.ClientError
        client_cos, bucket = self._create_cos_client(results_reference)

        try:
            if self._client.CLOUD_PLATFORM_SPACES:
                logs = results_reference.get("location").get("logs")
                if logs is None:
                    print(
                        " Metric log location details for this Training run is not available."
                    )
                    return
                key = logs + "/learner-1/evaluation-metrics.txt"
            else:
                # NOTE(review): unlike _COS_logs there is no "data/" prefix here;
                # kept as in the original — confirm which layout is correct.
                try:
                    key = (
                        run_details["metadata"].get(
                            "id", run_details["metadata"].get("guid")
                        )
                        + "/pipeline-model.json"
                    )
                    obj = client_cos.get_object(Bucket=bucket, Key=key)
                    pipeline_model = json.loads(obj["Body"].read().decode("utf-8"))
                except client_error as ex:
                    if ex.response["Error"]["Code"] == "NoSuchKey":
                        print(
                            "ERROR - Cannot find pipeline_model.json in the bucket for training id "
                            + run_id
                        )
                        print(
                            "There is no training logs are found for the given training run id"
                        )
                    else:
                        print(ex)
                    return

                # Subscript the parameters dict (``.get[...]`` raised TypeError).
                key = (
                    pipeline_model["pipelines"][0]["nodes"][0]["parameters"]["model_id"]
                    + "/learner-1/evaluation-metrics.txt"
                )

            obj = client_cos.get_object(Bucket=bucket, Key=key)
            print(obj["Body"].read().decode("utf-8"))

        except client_error as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                print("ERROR - Cannot find evaluation-metrics.txt in the bucket")
            else:
                print(ex)
                print(
                    "ERROR - Cannot get the location of evaluation-metrics.txt details in the bucket"
                )

    def monitor_logs(self, training_id: str | None = None, **kwargs: Any) -> None:
        """Print the logs of a training created.

        :param training_id: training ID
        :type training_id: str

        :raises WMLClientError: when called on IBM Cloud Pak® for Data

        .. note::

            This method is not supported for IBM Cloud Pak® for Data.

        **Example:**

        .. code-block:: python

            client.training.monitor_logs(training_id)
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=False
        )

        if self._client.ICP_PLATFORM_SPACES:
            raise WMLClientError(
                "Metrics logs are not supported. This method is not supported for IBM Cloud Pak® for Data."
            )

        Training._validate_type(training_id, "training_id", str, True)

        self._simple_monitor_logs(
            training_id,  # type: ignore
            lambda: print_text_header_h1(
                "Log monitor started for training run: " + str(training_id)
            ),
        )

        print_text_header_h2("Log monitor done.")

    def _get_monitor_websocket(self, training_id: str) -> WebSocket:
        """Create a lomond ``WebSocket`` for a training run's status stream.

        On cloud platforms the endpoint is scoped by the client's ``project_id``
        (preferred) or ``space_id`` query parameter. An ``Authorization: Bearer``
        header carrying the client token is attached.
        """
        if self._client.CLOUD_PLATFORM_SPACES:
            ws_param = self._client._params()
            base_url = self._credentials.url.replace("https", "wss")
            if "project_id" in ws_param.keys():
                monitor_endpoint = (
                    base_url
                    + "/ml/v4/trainings/"
                    + training_id
                    + "?project_id="
                    + ws_param.get("project_id")
                )
            else:
                monitor_endpoint = (
                    base_url
                    + "/ml/v4/trainings/"
                    + training_id
                    + "?space_id="
                    + ws_param.get("space_id")
                )
        else:
            monitor_endpoint = (
                self._credentials.url.replace("https", "wss")
                + "/v4/trainings/"
                + training_id
            )

        websocket = WebSocket(monitor_endpoint)
        # The former bare-except fallback called bytes(str) without an encoding,
        # which always raised TypeError; a single encoded header is used instead.
        websocket.add_header(
            bytes("Authorization", "utf-8"),
            bytes("Bearer " + self._client.token, "utf-8"),
        )
        return websocket

    def _simple_monitor_logs(
        self, training_id: str, on_start: Callable = lambda: {}
    ) -> None:
        """Print training logs: from COS for finished runs, streamed over a websocket otherwise."""
        try:
            run_details = self.get_details(training_id, _internal=True)
        except ApiRequestFailure as ex:
            if "404" in str(ex.args[1]):
                print(
                    "Could not find the training run details for the given training run id."
                )
                return
            raise

        status = run_details["entity"]["status"]["state"]

        if status in ("completed", "error", "failed", "canceled"):
            self._COS_logs(
                training_id,
                lambda: print_text_header_h1(
                    "Log monitor started for training run: " + str(training_id)
                ),
            )
            return

        websocket = self._get_monitor_websocket(training_id)
        try:
            on_start()
            for event in websocket:
                if event.name != "text":
                    continue
                entity = json.loads(event.text)["entity"]
                if "status" in entity and "message" in entity["status"]:
                    message = entity["status"]["message"]
                    if len(message) > 0:
                        print(message)
        finally:
            # Close the socket even if streaming or printing fails.
            websocket.close()

    def monitor_metrics(self, training_id: str | None = None, **kwargs: Any) -> None:
        """Print the metrics of a created training.

        :param training_id: ID of the training
        :type training_id: str

        :raises WMLClientError: when called on IBM Cloud Pak® for Data

        .. note::

            This method is not supported for IBM Cloud Pak® for Data.

        **Example:**

        .. code-block:: python

            client.training.monitor_metrics(training_id)
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=False
        )

        if self._client.ICP_PLATFORM_SPACES:
            raise WMLClientError(
                "Metrics monitoring is not supported for IBM Cloud Pak® for Data"
            )

        Training._validate_type(training_id, "training_id", str, True)

        try:
            run_details = self.get_details(training_id, _internal=True)
        except ApiRequestFailure as ex:
            if "404" in str(ex.args[1]):
                print(
                    "Could not find the training run details for the given training run id. "
                )
                return
            raise

        status = run_details["entity"]["status"]["state"]

        if status in ("completed", "error", "failed", "canceled"):
            self._COS_metrics(
                training_id,  # type: ignore
                lambda: print_text_header_h1(
                    "Log monitor started for training run: " + str(training_id)
                ),
            )
            return

        websocket = self._get_monitor_websocket(training_id)  # type: ignore
        try:
            print_text_header_h1(
                "Metric monitor started for training run: " + str(training_id)
            )
            for event in websocket:
                if event.name != "text":
                    continue
                entity = json.loads(event.text)["entity"]
                if "status" in entity:
                    event_status = entity["status"]
                    if "metrics" in event_status:
                        metrics = event_status["metrics"]
                        # Only the first metric entry of each event is printed.
                        if len(metrics) > 0:
                            print(metrics[0])
        finally:
            websocket.close()

        print_text_header_h2("Metric monitor done.")

    def get_metrics(
        self, training_id: str | None = None, **kwargs: Any
    ) -> ListType[dict]:
        """Get metrics of a training run.

        :param training_id: ID of the training
        :type training_id: str

        :return: metrics of the training run
        :rtype: list of dict

        :raises WMLClientError: if no metrics are available for the training

        **Example:**

        .. code-block:: python

            training_metrics = client.training.get_metrics(training_id)
        """
        training_id = _get_id_from_deprecated_uid(
            kwargs, training_id, "training", can_be_none=False
        )
        # Forwarded like in get_status/get_details so fine-tuning runs are found.
        _is_fine_tuning = kwargs.get("_is_fine_tuning", False)

        Training._validate_type(training_id, "training_id", str, True)

        status = self.get_status(training_id, _is_fine_tuning=_is_fine_tuning)
        if "metrics" in status:
            return status["metrics"]

        # NOTE(review): fallback to a top-level "metrics" key kept from the
        # original; status above already comes from details["entity"]["status"],
        # so confirm whether this lookup can ever succeed.
        details = self.get_details(
            training_id, _internal=True, _is_fine_tuning=_is_fine_tuning
        )
        if "metrics" in details:
            return details["metrics"]

        raise WMLClientError("No metrics details are available for the given training_id")

[8]
ページ先頭

©2009-2025 Movatter.jp