Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0c10eef

Browse files
Backport PR#30114: Fix _is_tensorflow_array. (#30120)
Co-authored-by: Antony Lee <anntzer.lee@gmail.com>
1 parent6557371 commit0c10eef

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

‎lib/matplotlib/cbook.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,42 +2311,56 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23112311

23122312

23132313
def_is_torch_array(x):
2314-
"""Check if 'x' is a PyTorch Tensor."""
2314+
"""Return whether *x* is a PyTorch Tensor."""
23152315
try:
2316-
#we're intentionally not attempting to import torch. If somebody
2317-
# has created a torch array, torch should already be in sys.modules
2318-
returnisinstance(x,sys.modules['torch'].Tensor)
2319-
exceptException:# TypeError, KeyError,AttributeError, maybe others?
2320-
# we're attempting to access attributes on imported modules which
2321-
#may have arbitrary user code, so we deliberately catch all exceptions
2322-
returnFalse
2316+
#We're intentionally not attempting to import torch. If somebody
2317+
# has created a torch array, torch should already be in sys.modules.
2318+
tp=sys.modules.get("torch").Tensor
2319+
exceptAttributeError:
2320+
returnFalse# Module not imported or a nonstandard module with no Tensor attr.
2321+
return (isinstance(tp,type)#Just in case it's a very nonstandard module.
2322+
andisinstance(x,tp))
23232323

23242324

23252325
def_is_jax_array(x):
2326-
"""Check if 'x' is a JAX Array."""
2326+
"""Return whether *x* is a JAX Array."""
23272327
try:
2328-
# we're intentionally not attempting to import jax. If somebody
2329-
# has created a jax array, jax should already be in sys.modules
2330-
returnisinstance(x,sys.modules['jax'].Array)
2331-
exceptException:# TypeError, KeyError, AttributeError, maybe others?
2332-
# we're attempting to access attributes on imported modules which
2333-
# may have arbitrary user code, so we deliberately catch all exceptions
2334-
returnFalse
2328+
# We're intentionally not attempting to import jax. If somebody
2329+
# has created a jax array, jax should already be in sys.modules.
2330+
tp=sys.modules.get("jax").Array
2331+
exceptAttributeError:
2332+
returnFalse# Module not imported or a nonstandard module with no Array attr.
2333+
return (isinstance(tp,type)# Just in case it's a very nonstandard module.
2334+
andisinstance(x,tp))
2335+
2336+
2337+
def_is_pandas_dataframe(x):
2338+
"""Check if *x* is a Pandas DataFrame."""
2339+
try:
2340+
# We're intentionally not attempting to import Pandas. If somebody
2341+
# has created a Pandas DataFrame, Pandas should already be in sys.modules.
2342+
tp=sys.modules.get("pandas").DataFrame
2343+
exceptAttributeError:
2344+
returnFalse# Module not imported or a nonstandard module with no Array attr.
2345+
return (isinstance(tp,type)# Just in case it's a very nonstandard module.
2346+
andisinstance(x,tp))
23352347

23362348

23372349
def_is_tensorflow_array(x):
2338-
"""Check if 'x' is a TensorFlow Tensor or Variable."""
2350+
"""Return whether *x* is a TensorFlow Tensor or Variable."""
23392351
try:
2340-
# we're intentionally not attempting to import TensorFlow. If somebody
2341-
# has created a TensorFlow array, TensorFlow should already be in sys.modules
2342-
# we use `is_tensor` to not depend on the class structure of TensorFlow
2343-
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
2344-
# (they both convert the same way)
2345-
returnisinstance(x,sys.modules['tensorflow'].is_tensor(x))
2346-
exceptException:# TypeError, KeyError, AttributeError, maybe others?
2347-
# we're attempting to access attributes on imported modules which
2348-
# may have arbitrary user code, so we deliberately catch all exceptions
2352+
# We're intentionally not attempting to import TensorFlow. If somebody
2353+
# has created a TensorFlow array, TensorFlow should already be in
2354+
# sys.modules we use `is_tensor` to not depend on the class structure
2355+
# of TensorFlow arrays, as `tf.Variables` are not instances of
2356+
# `tf.Tensor` (they both convert the same way).
2357+
is_tensor=sys.modules.get("tensorflow").is_tensor
2358+
exceptAttributeError:
23492359
returnFalse
2360+
try:
2361+
returnis_tensor(x)
2362+
exceptException:
2363+
returnFalse# Just in case it's a very nonstandard module.
23502364

23512365

23522366
def_unpack_to_numpy(x):
@@ -2401,15 +2415,3 @@ def _auto_format_str(fmt, value):
24012415
returnfmt% (value,)
24022416
except (TypeError,ValueError):
24032417
returnfmt.format(value)
2404-
2405-
2406-
def_is_pandas_dataframe(x):
2407-
"""Check if 'x' is a Pandas DataFrame."""
2408-
try:
2409-
# we're intentionally not attempting to import Pandas. If somebody
2410-
# has created a Pandas DataFrame, Pandas should already be in sys.modules
2411-
returnisinstance(x,sys.modules['pandas'].DataFrame)
2412-
exceptException:# TypeError, KeyError, AttributeError, maybe others?
2413-
# we're attempting to access attributes on imported modules which
2414-
# may have arbitrary user code, so we deliberately catch all exceptions
2415-
returnFalse

‎lib/matplotlib/tests/test_cbook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ def __array__(self):
983983
torch_tensor=torch.Tensor(data)
984984

985985
result=cbook._unpack_to_numpy(torch_tensor)
986+
assertisinstance(result,np.ndarray)
986987
# compare results, do not check for identity: the latter would fail
987988
# if not mocked, and the implementation does not guarantee it
988989
# is the same Python object, just the same values.
@@ -1011,6 +1012,7 @@ def __array__(self):
10111012
jax_array=jax.Array(data)
10121013

10131014
result=cbook._unpack_to_numpy(jax_array)
1015+
assertisinstance(result,np.ndarray)
10141016
# compare results, do not check for identity: the latter would fail
10151017
# if not mocked, and the implementation does not guarantee it
10161018
# is the same Python object, just the same values.
@@ -1040,6 +1042,7 @@ def __array__(self):
10401042
tf_tensor=tensorflow.Tensor(data)
10411043

10421044
result=cbook._unpack_to_numpy(tf_tensor)
1045+
assertisinstance(result,np.ndarray)
10431046
# compare results, do not check for identity: the latter would fail
10441047
# if not mocked, and the implementation does not guarantee it
10451048
# is the same Python object, just the same values.

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp