|
29 | 29 | importtorch._C |
30 | 30 | fromtorch._guardsimportGuard |
31 | 31 |
|
32 | | -from ..importvariables |
| 32 | +from ..importgraph_break_hints,variables |
33 | 33 | from ..bytecode_transformationimport ( |
34 | 34 | create_call_function, |
35 | 35 | create_instruction, |
36 | 36 | create_setup_with, |
37 | 37 | ) |
38 | 38 | from ..device_interfaceimportget_interface_for_device |
39 | | -from ..excimportunimplemented,Unsupported |
| 39 | +from ..excimportunimplemented_v2 |
40 | 40 | from ..guardsimportGuardBuilder,install_guard |
41 | 41 | from ..sourceimportAttrSource,GlobalStateSource |
42 | 42 | from .baseimportVariableTracker |
@@ -173,40 +173,27 @@ def fn_name(self): |
173 | 173 |
|
174 | 174 | defenter(self,tx): |
175 | 175 | source=Noneifself.sourceisNoneelseAttrSource(self.source,"__enter__") |
176 | | -try: |
177 | | -returnvariables.UserMethodVariable( |
178 | | -self.cm_obj.__enter__.__func__, |
179 | | -self, |
180 | | -source=source, |
181 | | - ).call_function(tx, [], {}) |
182 | | -exceptUnsupportedase: |
183 | | -unimplemented( |
184 | | -f"Unsupported context manager{self.cm_obj}'s __enter__ function", |
185 | | -from_exc=e, |
186 | | - ) |
| 176 | +returnvariables.UserMethodVariable( |
| 177 | +self.cm_obj.__enter__.__func__, |
| 178 | +self, |
| 179 | +source=source, |
| 180 | + ).call_function(tx, [], {}) |
187 | 181 |
|
188 | 182 | defexit(self,tx:"InstructionTranslator",*args): |
189 | 183 | source=Noneifself.sourceisNoneelseAttrSource(self.source,"__exit__") |
190 | | -try: |
191 | | -x=variables.UserMethodVariable( |
192 | | -self.cm_obj.__exit__.__func__, |
193 | | -self, |
194 | | -source=source, |
195 | | - ).call_function( |
196 | | -tx, |
197 | | - [ |
198 | | -variables.ConstantVariable.create(None), |
199 | | -variables.ConstantVariable.create(None), |
200 | | -variables.ConstantVariable.create(None), |
201 | | - ], |
202 | | - {}, |
203 | | - ) |
204 | | -exceptUnsupportedase: |
205 | | -unimplemented( |
206 | | -f"Unsupported context manager{self.cm_obj}'s __exit__ function", |
207 | | -from_exc=e, |
208 | | - ) |
209 | | - |
| 184 | +x=variables.UserMethodVariable( |
| 185 | +self.cm_obj.__exit__.__func__, |
| 186 | +self, |
| 187 | +source=source, |
| 188 | + ).call_function( |
| 189 | +tx, |
| 190 | + [ |
| 191 | +variables.ConstantVariable.create(None), |
| 192 | +variables.ConstantVariable.create(None), |
| 193 | +variables.ConstantVariable.create(None), |
| 194 | + ], |
| 195 | + {}, |
| 196 | + ) |
210 | 197 | tx.active_generic_context_managers.pop() |
211 | 198 | returnx |
212 | 199 |
|
@@ -921,11 +908,13 @@ def fn_name(self): |
921 | 908 | return"nullcontext" |
922 | 909 |
|
923 | 910 | defreconstruct(self,cg): |
924 | | -unimplemented( |
925 | | -""" |
926 | | -Dynamo doesn't support compiling a region that leaks torch profiler context |
927 | | -objects which will be used outside the region |
928 | | -""" |
| 911 | +unimplemented_v2( |
| 912 | +gb_type="torch.profiler object escaped from compiled region", |
| 913 | +context=str(self), |
| 914 | +explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", |
| 915 | +hints=[ |
| 916 | +*graph_break_hints.SUPPORTABLE, |
| 917 | + ], |
929 | 918 | ) |
930 | 919 |
|
931 | 920 |
|
@@ -1043,8 +1032,16 @@ def exit(self, tx: "InstructionTranslator", *args): |
1043 | 1032 | ).call_function(tx, [self.tensors,self.prev_versions], {}) |
1044 | 1033 |
|
1045 | 1034 | defreconstruct(self,codegen): |
1046 | | -unimplemented( |
1047 | | -"torch.autograd._unsafe_preserve_version_counter with graph break" |
| 1035 | +unimplemented_v2( |
| 1036 | +gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", |
| 1037 | +context=str(self), |
| 1038 | +explanation=( |
| 1039 | +"Dynamo doesn't support compiling a region that returns " |
| 1040 | +"a torch.autograd._unsafe_preserve_version_counter context manager." |
| 1041 | + ), |
| 1042 | +hints=[ |
| 1043 | +*graph_break_hints.SUPPORTABLE, |
| 1044 | + ], |
1048 | 1045 | ) |
1049 | 1046 |
|
1050 | 1047 |
|
@@ -1292,7 +1289,17 @@ def call_method( |
1292 | 1289 | ), |
1293 | 1290 | ) |
1294 | 1291 | else: |
1295 | | -unimplemented(f"event method{name} unsupported") |
| 1292 | +unimplemented_v2( |
| 1293 | +gb_type="Unsupported torch.cuda.Event method", |
| 1294 | +context=str(name), |
| 1295 | +explanation=( |
| 1296 | +f"Dynamo doesn't support tracing the torch.cuda.Event.{name} method. " |
| 1297 | +f"We currently support wait, record, synchronize, and query.", |
| 1298 | + ), |
| 1299 | +hints=[ |
| 1300 | +*graph_break_hints.SUPPORTABLE, |
| 1301 | + ], |
| 1302 | + ) |
1296 | 1303 |
|
1297 | 1304 | defas_proxy(self): |
1298 | 1305 | returnself.proxy |
|