Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7a734f4

Browse files
committed
Tweaks to model_graph to play nice with XTensorVariables
* Use RV Op name when provided* More robust detection of observed data variables (after#7656 arbitrary graphs are allowed)* Remove self loops explicitly (closes#7722)
1 parentc3cc833 commit7a734f4

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

‎pymc/model_graph.py‎

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
fromtypingimportAny,cast
2222

2323
frompytensorimportfunction
24-
frompytensor.graph.basicimportancestors,walk
24+
frompytensor.graph.basicimportVariable,ancestors,walk
2525
frompytensor.tensor.shapeimportShape
26-
frompytensor.tensor.variableimportTensorVariable
2726

2827
frompymc.model.coreimportmodelcontext
28+
frompymc.pytensorfimport_cheap_eval_mode
2929
frompymc.utilimportVarName,get_default_varnames,get_var_name
3030

3131
__all__= (
@@ -73,7 +73,7 @@ def create_plate_label_with_dim_length(
7373

7474

7575
deffast_eval(var):
76-
returnfunction([],var,mode="FAST_COMPILE")()
76+
returnfunction([],var,mode=_cheap_eval_mode)()
7777

7878

7979
classNodeType(str,Enum):
@@ -88,7 +88,7 @@ class NodeType(str, Enum):
8888

8989
@dataclass
9090
classNodeInfo:
91-
var:TensorVariable
91+
var:Variable
9292
node_type:NodeType
9393

9494
def__hash__(self):
@@ -108,10 +108,10 @@ def __eq__(self, other) -> bool:
108108

109109

110110
GraphvizNodeKwargs=dict[str,Any]
111-
NodeFormatter=Callable[[TensorVariable],GraphvizNodeKwargs]
111+
NodeFormatter=Callable[[Variable],GraphvizNodeKwargs]
112112

113113

114-
defdefault_potential(var:TensorVariable)->GraphvizNodeKwargs:
114+
defdefault_potential(var:Variable)->GraphvizNodeKwargs:
115115
"""Return default data for potential in the graph."""
116116
return {
117117
"shape":"octagon",
@@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
120120
}
121121

122122

123-
defrandom_variable_symbol(var:TensorVariable)->str:
123+
defrandom_variable_symbol(var:Variable)->str:
124124
"""Get the symbol of the random variable."""
125-
symbol=var.owner.op.__class__.__name__
125+
op=var.owner.op
126126

127-
ifsymbol.endswith("RV"):
128-
symbol=symbol[:-2]
127+
ifname:=getattr(op,"name",None):
128+
symbol=name[0].upper()+name[1:]
129+
else:
130+
symbol=op.__class__.__name__.removesuffix("RV")
129131

130132
returnsymbol
131133

132134

133-
defdefault_free_rv(var:TensorVariable)->GraphvizNodeKwargs:
135+
defdefault_free_rv(var:Variable)->GraphvizNodeKwargs:
134136
"""Return default data for free RV in the graph."""
135137
symbol=random_variable_symbol(var)
136138

@@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
141143
}
142144

143145

144-
defdefault_observed_rv(var:TensorVariable)->GraphvizNodeKwargs:
146+
defdefault_observed_rv(var:Variable)->GraphvizNodeKwargs:
145147
"""Return default data for observed RV in the graph."""
146148
symbol=random_variable_symbol(var)
147149

@@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
152154
}
153155

154156

155-
defdefault_deterministic(var:TensorVariable)->GraphvizNodeKwargs:
157+
defdefault_deterministic(var:Variable)->GraphvizNodeKwargs:
156158
"""Return default data for the deterministic in the graph."""
157159
return {
158160
"shape":"box",
@@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
161163
}
162164

163165

164-
defdefault_data(var:TensorVariable)->GraphvizNodeKwargs:
166+
defdefault_data(var:Variable)->GraphvizNodeKwargs:
165167
"""Return default data for the data in the graph."""
166168
return {
167169
"shape":"box",
@@ -239,7 +241,7 @@ def __init__(self, model):
239241
self._all_vars= {model[var_name]forvar_nameinself._all_var_names}
240242
self.var_list=self.model.named_vars.values()
241243

242-
defget_parent_names(self,var:TensorVariable)->set[VarName]:
244+
defget_parent_names(self,var:Variable)->set[VarName]:
243245
ifvar.ownerisNone:
244246
returnset()
245247

@@ -345,7 +347,7 @@ def get_plates(
345347
dim_name:fast_eval(value).item()fordim_name,valueinself.model.dim_lengths.items()
346348
}
347349
var_shapes:dict[str,tuple[int, ...]]= {
348-
var_name:tuple(fast_eval(self.model[var_name].shape))
350+
var_name:tuple(map(int,fast_eval(self.model[var_name].shape)))
349351
forvar_nameinself.vars_to_plot(var_names)
350352
}
351353

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp