@@ -64,7 +64,10 @@ def __init__(self, i, atts):
6464self .nodes_missing_value_tracks_true = None
6565for k ,v in atts .items ():
6666if k .startswith ("nodes" ):
67- setattr (self ,k ,v [i ])
67+ if k .endswith ("_as_tensor" ):
68+ setattr (self ,k .replace ("_as_tensor" ,"" ),v [i ])
69+ else :
70+ setattr (self ,k ,v [i ])
6871self .depth = 0
6972self .true_false = ""
7073self .targets = []
@@ -120,10 +123,7 @@ def process_tree(atts, treeid):
120123 ]
121124for k ,v in atts .items ():
122125if k .startswith (prefix ):
123- if "classlabels" in k :
124- short [k ]= list (v )
125- else :
126- short [k ]= [v [i ]for i in idx ]
126+ short [k ]= list (v )if "classlabels" in k else [v [i ]for i in idx ]
127127
128128nodes = OrderedDict ()
129129for i in range (len (short ["nodes_treeids" ])):
@@ -132,9 +132,10 @@ def process_tree(atts, treeid):
132132for i in range (len (short [f"{ prefix } _treeids" ])):
133133idn = short [f"{ prefix } _nodeids" ][i ]
134134node = nodes [idn ]
135- node .append_target (
136- tid = short [f"{ prefix } _ids" ][i ],weight = short [f"{ prefix } _weights" ][i ]
137- )
135+ key = f"{ prefix } _weights"
136+ if key not in short :
137+ key = f"{ prefix } _weights_as_tensor"
138+ node .append_target (tid = short [f"{ prefix } _ids" ][i ],weight = short [key ][i ])
138139
139140def iterate (nodes ,node ,depth = 0 ,true_false = "" ):
140141node .depth = depth