Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit42f1905

Browse files
authored
Fix _is_tensorflow_array. (#30114)
The previous implementation was clearly wrong (the isinstance checkwould raise TypeError as the second argument would be a bool), but thetests didn't catch that because the bug led to _is_tensorflow_arrayreturning False, then _unpack_to_numpy returning the original input,and then assert_array_equal implicitly converting `result` by calling`__array__` on it. Fix the test by explicitly checking that `result`is indeed a numpy array, and also fix _is_tensorflow_array with morerestrictive exception catching (also applied to _is_torch_array,_is_jax_array, and _is_pandas_dataframe, while we're at it).
1 parentb18407b commit42f1905

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

‎lib/matplotlib/cbook.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,42 +2331,56 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23312331

23322332

23332333
def_is_torch_array(x):
2334-
"""Check if 'x' is a PyTorch Tensor."""
2334+
"""Return whether *x* is a PyTorch Tensor."""
23352335
try:
2336-
#we're intentionally not attempting to import torch. If somebody
2337-
# has created a torch array, torch should already be in sys.modules
2338-
returnisinstance(x,sys.modules['torch'].Tensor)
2339-
exceptException:# TypeError, KeyError,AttributeError, maybe others?
2340-
# we're attempting to access attributes on imported modules which
2341-
#may have arbitrary user code, so we deliberately catch all exceptions
2342-
returnFalse
2336+
#We're intentionally not attempting to import torch. If somebody
2337+
# has created a torch array, torch should already be in sys.modules.
2338+
tp=sys.modules.get("torch").Tensor
2339+
exceptAttributeError:
2340+
returnFalse# Module not imported or a nonstandard module with no Tensor attr.
2341+
return (isinstance(tp,type)#Just in case it's a very nonstandard module.
2342+
andisinstance(x,tp))
23432343

23442344

23452345
def_is_jax_array(x):
2346-
"""Check if 'x' is a JAX Array."""
2346+
"""Return whether *x* is a JAX Array."""
23472347
try:
2348-
# we're intentionally not attempting to import jax. If somebody
2349-
# has created a jax array, jax should already be in sys.modules
2350-
returnisinstance(x,sys.modules['jax'].Array)
2351-
exceptException:# TypeError, KeyError, AttributeError, maybe others?
2352-
# we're attempting to access attributes on imported modules which
2353-
# may have arbitrary user code, so we deliberately catch all exceptions
2354-
returnFalse
2348+
# We're intentionally not attempting to import jax. If somebody
2349+
# has created a jax array, jax should already be in sys.modules.
2350+
tp=sys.modules.get("jax").Array
2351+
exceptAttributeError:
2352+
returnFalse# Module not imported or a nonstandard module with no Array attr.
2353+
return (isinstance(tp,type)# Just in case it's a very nonstandard module.
2354+
andisinstance(x,tp))
2355+
2356+
2357+
def_is_pandas_dataframe(x):
2358+
"""Check if *x* is a Pandas DataFrame."""
2359+
try:
2360+
# We're intentionally not attempting to import Pandas. If somebody
2361+
# has created a Pandas DataFrame, Pandas should already be in sys.modules.
2362+
tp=sys.modules.get("pandas").DataFrame
2363+
exceptAttributeError:
2364+
returnFalse# Module not imported or a nonstandard module with no Array attr.
2365+
return (isinstance(tp,type)# Just in case it's a very nonstandard module.
2366+
andisinstance(x,tp))
23552367

23562368

23572369
def_is_tensorflow_array(x):
2358-
"""Check if 'x' is a TensorFlow Tensor or Variable."""
2370+
"""Return whether *x* is a TensorFlow Tensor or Variable."""
23592371
try:
2360-
# we're intentionally not attempting to import TensorFlow. If somebody
2361-
# has created a TensorFlow array, TensorFlow should already be in sys.modules
2362-
# we use `is_tensor` to not depend on the class structure of TensorFlow
2363-
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
2364-
# (they both convert the same way)
2365-
returnisinstance(x,sys.modules['tensorflow'].is_tensor(x))
2366-
exceptException:# TypeError, KeyError, AttributeError, maybe others?
2367-
# we're attempting to access attributes on imported modules which
2368-
# may have arbitrary user code, so we deliberately catch all exceptions
2372+
# We're intentionally not attempting to import TensorFlow. If somebody
2373+
# has created a TensorFlow array, TensorFlow should already be in
2374+
# sys.modules we use `is_tensor` to not depend on the class structure
2375+
# of TensorFlow arrays, as `tf.Variables` are not instances of
2376+
# `tf.Tensor` (they both convert the same way).
2377+
is_tensor=sys.modules.get("tensorflow").is_tensor
2378+
exceptAttributeError:
23692379
returnFalse
2380+
try:
2381+
returnis_tensor(x)
2382+
exceptException:
2383+
returnFalse# Just in case it's a very nonstandard module.
23702384

23712385

23722386
def_unpack_to_numpy(x):
@@ -2421,15 +2435,3 @@ def _auto_format_str(fmt, value):
24212435
returnfmt% (value,)
24222436
except (TypeError,ValueError):
24232437
returnfmt.format(value)
2424-
2425-
2426-
def_is_pandas_dataframe(x):
2427-
"""Check if 'x' is a Pandas DataFrame."""
2428-
try:
2429-
# we're intentionally not attempting to import Pandas. If somebody
2430-
# has created a Pandas DataFrame, Pandas should already be in sys.modules
2431-
returnisinstance(x,sys.modules['pandas'].DataFrame)
2432-
exceptException:# TypeError, KeyError, AttributeError, maybe others?
2433-
# we're attempting to access attributes on imported modules which
2434-
# may have arbitrary user code, so we deliberately catch all exceptions
2435-
returnFalse

‎lib/matplotlib/tests/test_cbook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,7 @@ def __array__(self):
10001000
torch_tensor=torch.Tensor(data)
10011001

10021002
result=cbook._unpack_to_numpy(torch_tensor)
1003+
assertisinstance(result,np.ndarray)
10031004
# compare results, do not check for identity: the latter would fail
10041005
# if not mocked, and the implementation does not guarantee it
10051006
# is the same Python object, just the same values.
@@ -1028,6 +1029,7 @@ def __array__(self):
10281029
jax_array=jax.Array(data)
10291030

10301031
result=cbook._unpack_to_numpy(jax_array)
1032+
assertisinstance(result,np.ndarray)
10311033
# compare results, do not check for identity: the latter would fail
10321034
# if not mocked, and the implementation does not guarantee it
10331035
# is the same Python object, just the same values.
@@ -1057,6 +1059,7 @@ def __array__(self):
10571059
tf_tensor=tensorflow.Tensor(data)
10581060

10591061
result=cbook._unpack_to_numpy(tf_tensor)
1062+
assertisinstance(result,np.ndarray)
10601063
# compare results, do not check for identity: the latter would fail
10611064
# if not mocked, and the implementation does not guarantee it
10621065
# is the same Python object, just the same values.

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp