Note
Go to the end to download the full example code.
Demonstration for parsing JSON/UBJSON tree model files
See Introduction to Model IO for details about the model serialization.
import argparse
import json
from dataclasses import dataclass
from enum import IntEnum, unique
from typing import Any, Dict, List, Sequence, Union

import numpy as np

try:
    import ubjson
except ImportError:
    # ubjson is optional; it is only required when a ``.ubj`` model is loaded.
    ubjson = None


ParamT = Dict[str, str]


def to_integers(data: Union[bytes, List[int]]) -> List[int]:
    """Convert a sequence of bytes to a list of Python integer"""
    # Iterating ``bytes`` yields ints, so ``list`` handles both input kinds.
    return list(data)


@unique
class SplitType(IntEnum):
    """Kind of split stored in a tree node."""

    numerical = 0
    categorical = 1


@dataclass
class Node:
    """All per-node fields of one tree node, gathered into a single record."""

    # properties
    left: int
    right: int
    parent: int
    split_idx: int
    split_cond: float
    default_left: bool
    split_type: SplitType
    categories: List[int]
    # statistic
    base_weight: float
    loss_chg: float
    sum_hess: float


class Tree:
    """A tree built by XGBoost."""

    def __init__(self, tree_id: int, nodes: Sequence[Node]) -> None:
        self.tree_id = tree_id
        self.nodes = nodes

    def loss_change(self, node_id: int) -> float:
        """Loss gain of a node."""
        return self.nodes[node_id].loss_chg

    def sum_hessian(self, node_id: int) -> float:
        """Sum Hessian of a node."""
        return self.nodes[node_id].sum_hess

    def base_weight(self, node_id: int) -> float:
        """Base weight of a node."""
        return self.nodes[node_id].base_weight

    def split_index(self, node_id: int) -> int:
        """Split feature index of node."""
        return self.nodes[node_id].split_idx

    def split_condition(self, node_id: int) -> float:
        """Split value of a node."""
        return self.nodes[node_id].split_cond

    def split_categories(self, node_id: int) -> List[int]:
        """Categories in a node."""
        return self.nodes[node_id].categories

    def is_categorical(self, node_id: int) -> bool:
        """Whether a node has categorical split."""
        return self.nodes[node_id].split_type == SplitType.categorical

    def is_numerical(self, node_id: int) -> bool:
        """Whether a node does not have a categorical split."""
        return not self.is_categorical(node_id)

    def parent(self, node_id: int) -> int:
        """Parent ID of a node."""
        return self.nodes[node_id].parent

    def left_child(self, node_id: int) -> int:
        """Left child ID of a node."""
        return self.nodes[node_id].left

    def right_child(self, node_id: int) -> int:
        """Right child ID of a node."""
        return self.nodes[node_id].right

    def is_leaf(self, node_id: int) -> bool:
        """Whether a node is leaf."""
        # A node whose left child index is -1 is treated as a leaf.
        return self.nodes[node_id].left == -1

    def is_deleted(self, node_id: int) -> bool:
        """Whether a node is deleted."""
        # A split index equal to the largest uint32 value marks deletion.
        return self.split_index(node_id) == np.iinfo(np.uint32).max

    def __str__(self) -> str:
        """Render the nodes reachable from node 0, one dictionary per line.

        Nodes are visited with an explicit stack; the left child is pushed
        before the right child, so the right subtree is listed first.  An
        empty tree renders as an empty string.
        """
        # Start at node 0 only when there is a node to visit.
        stack = [0] if len(self.nodes) > 0 else []
        nodes = []
        while stack:
            node: Dict[str, Union[float, int, List[int]]] = {}
            nid = stack.pop()

            node["node id"] = nid
            node["gain"] = self.loss_change(nid)
            node["cover"] = self.sum_hessian(nid)
            nodes.append(node)

            if not self.is_leaf(nid) and not self.is_deleted(nid):
                stack.append(self.left_child(nid))
                stack.append(self.right_child(nid))
                categories = self.split_categories(nid)
                if categories:
                    assert self.is_categorical(nid)
                    node["categories"] = categories
                else:
                    assert self.is_numerical(nid)
                    node["condition"] = self.split_condition(nid)
            if self.is_leaf(nid):
                # For a leaf the split-condition field is reported as weight.
                node["weight"] = self.split_condition(nid)

        return "\n".join(" " + str(node) for node in nodes)


class Model:
    """Gradient boosted tree model."""

    def __init__(self, model: dict) -> None:
        """Construct the Model from a JSON object.

        Parameters
        ----------
        model :
            A dictionary loaded by json representing a XGBoost boosted tree
            model.

        Raises
        ------
        AssertionError
            If a tree's ``id`` does not match its position, or if the
            per-node arrays of a tree are inconsistent.
        """
        # Basic properties of a model
        self.learner_model_shape: ParamT = model["learner"]["learner_model_param"]
        self.num_output_group = int(self.learner_model_shape["num_class"])
        self.num_feature = int(self.learner_model_shape["num_feature"])
        # ``base_score`` is stored as a string and decoded as JSON.
        self.base_score: List[float] = json.loads(
            self.learner_model_shape["base_score"]
        )
        # A field encoding which output group a tree belongs
        self.tree_info = model["learner"]["gradient_booster"]["model"]["tree_info"]

        model_shape: ParamT = model["learner"]["gradient_booster"]["model"][
            "gbtree_model_param"
        ]

        # JSON representation of trees
        j_trees = model["learner"]["gradient_booster"]["model"]["trees"]

        # Load the trees
        self.num_trees = int(model_shape["num_trees"])

        trees: List[Tree] = []
        for i in range(self.num_trees):
            tree: Dict[str, Any] = j_trees[i]
            tree_id = int(tree["id"])
            # Trees are expected to be stored in order of their id.
            assert tree_id == i, (tree_id, i)
            trees.append(self._parse_tree(tree_id, tree))

        self.trees = trees

    @staticmethod
    def _node_categories(
        n_nodes: int,
        cat_nodes: List[int],
        cat_segments: List[int],
        cat_sizes: List[int],
        cats: List[int],
    ) -> List[List[int]]:
        """Expand the CSR-style category storage into one list per node."""
        # The storage for categories is only defined for categorical nodes to
        # prevent unnecessary overhead for numerical splits, we track the
        # categorical node that are processed using a counter.
        cat_cnt = 0
        last_cat_node = cat_nodes[cat_cnt] if cat_nodes else -1

        node_categories: List[List[int]] = []
        for node_id in range(n_nodes):
            if node_id == last_cat_node:
                beg = cat_segments[cat_cnt]
                end = beg + cat_sizes[cat_cnt]
                node_cats = cats[beg:end]
                # categories are unique for each node
                assert len(set(node_cats)) == len(node_cats)
                cat_cnt += 1
                if cat_cnt == len(cat_nodes):
                    last_cat_node = -1  # continue to process the rest of the nodes
                else:
                    last_cat_node = cat_nodes[cat_cnt]
                assert node_cats
                node_categories.append(node_cats)
            else:
                # append an empty node, it's either a numerical node or a leaf.
                node_categories.append([])
        return node_categories

    @staticmethod
    def _parse_tree(tree_id: int, tree: Dict[str, Any]) -> Tree:
        """Build a ``Tree`` from the per-node arrays of one serialized tree."""
        # - properties
        left_children: List[int] = tree["left_children"]
        right_children: List[int] = tree["right_children"]
        parents: List[int] = tree["parents"]
        split_conditions: List[float] = tree["split_conditions"]
        split_indices: List[int] = tree["split_indices"]
        # when ubjson is used, this is a byte array with each element as uint8
        default_left = to_integers(tree["default_left"])

        # - categorical features
        # when ubjson is used, this is a byte array with each element as uint8
        split_types = to_integers(tree["split_type"])
        # categories for each node is stored in a CSR style storage with segment as
        # the begin ptr and the `categories' as values.
        cat_segments: List[int] = tree["categories_segments"]
        cat_sizes: List[int] = tree["categories_sizes"]
        # node index for categorical nodes
        cat_nodes: List[int] = tree["categories_nodes"]
        assert len(cat_segments) == len(cat_sizes) == len(cat_nodes)
        cats = tree["categories"]
        assert len(left_children) == len(split_types)

        n_nodes = len(left_children)
        node_categories = Model._node_categories(
            n_nodes, cat_nodes, cat_segments, cat_sizes, cats
        )

        # - stats
        base_weights: List[float] = tree["base_weights"]
        loss_changes: List[float] = tree["loss_changes"]
        sum_hessian: List[float] = tree["sum_hessian"]

        # Construct a list of nodes that have complete information
        nodes: List[Node] = [
            Node(
                left_children[node_id],
                right_children[node_id],
                parents[node_id],
                split_indices[node_id],
                split_conditions[node_id],
                default_left[node_id] == 1,  # to boolean
                SplitType(split_types[node_id]),
                node_categories[node_id],
                base_weights[node_id],
                loss_changes[node_id],
                sum_hessian[node_id],
            )
            for node_id in range(n_nodes)
        ]
        return Tree(tree_id, nodes)

    def print_model(self) -> None:
        """Print every tree, each preceded by its position in the model."""
        for i, tree in enumerate(self.trees):
            print("\ntree_id:", i)
            print(tree)


def main() -> None:
    """Load the model file given by ``--model`` and print its trees.

    Raises
    ------
    ImportError
        If a ``ubj`` file is given but ``ubjson`` is not installed.
    ValueError
        If the path ends with neither ``json`` nor ``ubj``.
    """
    parser = argparse.ArgumentParser(
        description="Demonstration for loading XGBoost JSON/UBJSON model."
    )
    parser.add_argument(
        "--model", type=str, required=True, help="Path to .json/.ubj model file."
    )
    args = parser.parse_args()
    if args.model.endswith("json"):
        # use json format; JSON text is UTF-8, so do not rely on the locale.
        with open(args.model, "r", encoding="utf-8") as fd:
            raw_model = json.load(fd)
    elif args.model.endswith("ubj"):
        if ubjson is None:
            raise ImportError("ubjson is not installed.")
        # use ubjson format
        with open(args.model, "rb") as bfd:
            raw_model = ubjson.load(bfd)
    else:
        raise ValueError(
            "Unexpected file extension. Supported file extension are json and ubj."
        )
    Model(raw_model).print_model()


if __name__ == "__main__":
    main()