2121from typing import Any ,cast
2222
2323from pytensor import function
24- from pytensor .graph .basic import ancestors ,walk
24+ from pytensor .graph .basic import Variable , ancestors ,walk
2525from pytensor .tensor .shape import Shape
26- from pytensor .tensor .variable import TensorVariable
2726
2827from pymc .model .core import modelcontext
28+ from pymc .pytensorf import _cheap_eval_mode
2929from pymc .util import VarName ,get_default_varnames ,get_var_name
3030
3131__all__ = (
@@ -73,7 +73,7 @@ def create_plate_label_with_dim_length(
7373
7474
7575def fast_eval (var ):
76- return function ([],var ,mode = "FAST_COMPILE" )()
76+ return function ([],var ,mode = _cheap_eval_mode )()
7777
7878
7979class NodeType (str ,Enum ):
@@ -88,7 +88,7 @@ class NodeType(str, Enum):
8888
8989@dataclass
9090class NodeInfo :
91- var :TensorVariable
91+ var :Variable
9292node_type :NodeType
9393
9494def __hash__ (self ):
@@ -108,10 +108,10 @@ def __eq__(self, other) -> bool:
108108
109109
110110GraphvizNodeKwargs = dict [str ,Any ]
111- NodeFormatter = Callable [[TensorVariable ],GraphvizNodeKwargs ]
111+ NodeFormatter = Callable [[Variable ],GraphvizNodeKwargs ]
112112
113113
114- def default_potential (var :TensorVariable )-> GraphvizNodeKwargs :
114+ def default_potential (var :Variable )-> GraphvizNodeKwargs :
115115"""Return default data for potential in the graph."""
116116return {
117117"shape" :"octagon" ,
@@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
120120 }
121121
122122
123- def random_variable_symbol (var :TensorVariable )-> str :
123+ def random_variable_symbol (var :Variable )-> str :
124124"""Get the symbol of the random variable."""
125- symbol = var .owner .op . __class__ . __name__
125+ op = var .owner .op
126126
127- if symbol .endswith ("RV" ):
128- symbol = symbol [:- 2 ]
127+ if name := getattr (op ,"name" ,None ):
128+ symbol = name [0 ].upper ()+ name [1 :]
129+ else :
130+ symbol = op .__class__ .__name__ .removesuffix ("RV" )
129131
130132return symbol
131133
132134
133- def default_free_rv (var :TensorVariable )-> GraphvizNodeKwargs :
135+ def default_free_rv (var :Variable )-> GraphvizNodeKwargs :
134136"""Return default data for free RV in the graph."""
135137symbol = random_variable_symbol (var )
136138
@@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
141143 }
142144
143145
144- def default_observed_rv (var :TensorVariable )-> GraphvizNodeKwargs :
146+ def default_observed_rv (var :Variable )-> GraphvizNodeKwargs :
145147"""Return default data for observed RV in the graph."""
146148symbol = random_variable_symbol (var )
147149
@@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
152154 }
153155
154156
155- def default_deterministic (var :TensorVariable )-> GraphvizNodeKwargs :
157+ def default_deterministic (var :Variable )-> GraphvizNodeKwargs :
156158"""Return default data for the deterministic in the graph."""
157159return {
158160"shape" :"box" ,
@@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
161163 }
162164
163165
164- def default_data (var :TensorVariable )-> GraphvizNodeKwargs :
166+ def default_data (var :Variable )-> GraphvizNodeKwargs :
165167"""Return default data for the data in the graph."""
166168return {
167169"shape" :"box" ,
@@ -239,7 +241,7 @@ def __init__(self, model):
239241self ._all_vars = {model [var_name ]for var_name in self ._all_var_names }
240242self .var_list = self .model .named_vars .values ()
241243
242- def get_parent_names (self ,var :TensorVariable )-> set [VarName ]:
244+ def get_parent_names (self ,var :Variable )-> set [VarName ]:
243245if var .owner is None :
244246return set ()
245247
@@ -345,7 +347,7 @@ def get_plates(
345347dim_name :fast_eval (value ).item ()for dim_name ,value in self .model .dim_lengths .items ()
346348 }
347349var_shapes :dict [str ,tuple [int , ...]]= {
348- var_name :tuple (fast_eval (self .model [var_name ].shape ))
350+ var_name :tuple (map ( int , fast_eval (self .model [var_name ].shape ) ))
349351for var_name in self .vars_to_plot (var_names )
350352 }
351353