Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc28aeb2

Browse files
committed
fix onnx_text_plot_tree
1 parent674eb27 commitc28aeb2

File tree

3 files changed

+64
-19
lines changed

3 files changed

+64
-19
lines changed
1.38 KB
Binary file not shown.

‎_unittests/ut_plotting/test_text_plot.py‎

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_onnx_text_plot_tree_reg(self):
5252
onx=to_onnx(clr,X)
5353
res=onnx_text_plot_tree(onx.graph.node[0])
5454
self.assertIn("treeid=0",res)
55-
self.assertIn("T y=",res)
55+
self.assertIn("+f",res)
5656

5757
deftest_onnx_text_plot_tree_cls(self):
5858
iris=load_iris()
@@ -62,20 +62,43 @@ def test_onnx_text_plot_tree_cls(self):
6262
onx=to_onnx(clr,X)
6363
res=onnx_text_plot_tree(onx.graph.node[0])
6464
self.assertIn("treeid=0",res)
65-
self.assertIn("T y=",res)
65+
self.assertIn("+f 0:",res)
6666
self.assertIn("n_classes=3",res)
6767

6868
deftest_onnx_text_plot_tree_cls_2(self):
69-
iris=load_iris()
70-
X_train,y_train=iris.data.astype(numpy.float32),iris.target
71-
clr=DecisionTreeClassifier()
72-
clr.fit(X_train,y_train)
73-
model_def=to_onnx(
74-
clr,X_train.astype(numpy.float32),options={"zipmap":False}
69+
this=os.path.join(
70+
os.path.dirname(__file__),"data","onnx_text_plot_tree_cls_2.onnx"
7571
)
72+
withopen(this,"rb")asf:
73+
model_def=load(f)
7674
res=onnx_text_plot_tree(model_def.graph.node[0])
7775
self.assertIn("n_classes=3",res)
78-
print(res)
76+
expected=textwrap.dedent(
77+
"""
78+
n_classes=3
79+
n_trees=1
80+
----
81+
treeid=0
82+
n X2 <= 2.4499998
83+
-n X3 <= 1.75
84+
-n X2 <= 4.85
85+
-f 0:0 1:0 2:1
86+
+n X0 <= 5.95
87+
-f 0:0 1:0 2:1
88+
+f 0:0 1:1 2:0
89+
+n X2 <= 4.95
90+
-n X3 <= 1.55
91+
-n X0 <= 6.95
92+
-f 0:0 1:0 2:1
93+
+f 0:0 1:1 2:0
94+
+f 0:0 1:0 2:1
95+
+n X3 <= 1.65
96+
-f 0:0 1:0 2:1
97+
+f 0:0 1:1 2:0
98+
+f 0:1 1:0 2:0
99+
"""
100+
).strip("\n\r")
101+
self.assertEqual(expected,res.strip("\n\r"))
79102

80103
@ignore_warnings((UserWarning,FutureWarning))
81104
deftest_onnx_simple_text_plot_kmeans(self):

‎onnx_array_api/plotting/text_plot.py‎

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def _rule(r):
2424
raiseValueError(f"Unexpected rule{r!r}.")
2525

2626

27+
def_number2str(i):
28+
ifisinstance(i,int):
29+
returnstr(i)
30+
ifint(i)==i:
31+
returnstr(int(i))
32+
returnf"{i:1.2f}"
33+
34+
2735
defonnx_text_plot_tree(node):
2836
"""
2937
Gives a textual representation of a tree ensemble.
@@ -61,18 +69,32 @@ def __init__(self, i, atts):
6169
setattr(self,k,v[i])
6270
self.depth=0
6371
self.true_false=""
72+
self.targets= []
73+
74+
defappend_target(self,tid,weight):
75+
self.targets.append(dict(target_id=tid,weight=weight))
6476

6577
defprocess_node(self):
6678
"node to string"
6779
ifself.nodes_modes=="LEAF":
68-
text="%s y=%r f=%r i=%r"% (
69-
self.true_false,
70-
self.target_weights,
71-
self.target_ids,
72-
self.target_nodeids,
73-
)
80+
iflen(self.targets)==0:
81+
text=f"{self.true_false}f"
82+
eliflen(self.targets)==1:
83+
t=self.targets[0]
84+
text= (
85+
f"{self.true_false}f "
86+
f"{t['target_id']}:{_number2str(t['weight'])}"
87+
)
88+
else:
89+
ts=" ".join(
90+
map(
91+
lambdat:f"{t['target_id']}:{_number2str(t['weight'])}",
92+
self.targets,
93+
)
94+
)
95+
text=f"{self.true_false}f{ts}"
7496
else:
75-
text="%s X%d %s %r"% (
97+
text="%sn X%d %s %r"% (
7698
self.true_false,
7799
self.nodes_featureids,
78100
_rule(self.nodes_modes),
@@ -115,7 +137,7 @@ def process_tree(atts, treeid):
115137
idn=short[f"{prefix}_nodeids"][i]
116138
node=nodes[idn]
117139
node.append_target(
118-
id=short[f"{prefix}_ids"][i],weight=short[f"{prefix}_weights"][i]
140+
tid=short[f"{prefix}_ids"][i],weight=short[f"{prefix}_weights"][i]
119141
)
120142

121143
defiterate(nodes,node,depth=0,true_false=""):
@@ -127,14 +149,14 @@ def iterate(nodes, node, depth=0, true_false=""):
127149
nodes,
128150
nodes[node.nodes_falsenodeids],
129151
depth=depth+1,
130-
true_false="F",
152+
true_false="-",
131153
):
132154
yieldn
133155
forniniterate(
134156
nodes,
135157
nodes[node.nodes_truenodeids],
136158
depth=depth+1,
137-
true_false="T",
159+
true_false="+",
138160
):
139161
yieldn
140162

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp