How to discover an array type's namespace from an input type annotation #948
In machine learning workflow engines, I've come across a lot of code similar to this (with JAX as the array API implementer of choice in this example):

```python
import jax
import jax.numpy as jnp

# `task` and `workflow` are decorators provided by the workflow library.


@task
def zeros() -> jax.Array:
    return jnp.zeros(5)


@task
def compute_sum(arr: jax.Array) -> float:
    return jnp.sum(arr).item()


@workflow
def basic_workflow():
    arr = zeros()
    compute_sum(arr)


if __name__ == "__main__":
    basic_workflow()
```

where the array I/O needed for the output of

A lot of workflow libraries allow the user to customize said I/O, perhaps with an interface like the following:

```python
from typing import Any

import jax


class MyJAXArrayIO:
    def load(self, val: Any, typ: type[jax.Array]) -> jax.Array:
        """Instantiate a loaded value as a JAX array."""
        ...

    def store(self, val: jax.Array) -> None:
        """Write a JAX array, e.g. to disk."""
        ...
```

which would be called to instantiate all of the arrays that come out of

I think the array API standard would be a good fit for generalizing such I/O machinery to handle any type that implements the standard, since it could provide a very basic load mechanism like this:

```python
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class MyArrayStandardIO(Generic[T]):
    def load(self, val: Any, typ: type[T]) -> T:
        """Instantiate a loaded value as an array-API-conforming array."""
        # Step 1: look up the type's array API namespace...
        namespace = typ.__array_namespace__()
        # ...and return the value as an array in that namespace.
        return namespace.asarray(val)
```

Unfortunately, it seems that

```pycon
>>> import jax
>>> jax.Array.__array_namespace__()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/nicholasjunge/.../site-packages/jax/_src/numpy/array_methods.py", line 1133, in method
    raise NotImplementedError(f"Cannot call abstract method {name}")
NotImplementedError: Cannot call abstract method __array_namespace__
```

NumPy produces a different error, but I'm also unable to obtain the namespace there. As an added difficulty, JAX's array namespace is

Is there any standards-conforming way I can deduce the array namespace from an array type annotation only?
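The traceback is consistent with how the array API standard specifies `__array_namespace__`: it is an instance method on array objects (`x.__array_namespace__(api_version=None)`), so it is meant to be called on an array, not on the array type. One possible workaround, sketched here under clear assumptions, is for the I/O layer to keep its own mapping from array types to namespaces and to resolve an annotation by walking the type's MRO. The registry and the helpers `register_namespace`, `namespace_for_type` and `load_as` are hypothetical (not part of the standard or of any library), and `ToyArray`/`toy_namespace` are stand-ins so the sketch runs without JAX or NumPy installed.

```python
import types
from typing import Any, Callable

# Hypothetical registry (not defined by the array API standard): maps an array
# *type* to a zero-argument callable returning that type's namespace.
_NAMESPACE_REGISTRY: dict[type, Callable[[], Any]] = {}


def register_namespace(array_type: type, get_namespace: Callable[[], Any]) -> None:
    """Associate an array type with a namespace factory (hypothetical helper)."""
    _NAMESPACE_REGISTRY[array_type] = get_namespace


def namespace_for_type(typ: type) -> Any:
    """Resolve a namespace from a type annotation, checking base classes too."""
    for base in typ.__mro__:
        factory = _NAMESPACE_REGISTRY.get(base)
        if factory is not None:
            return factory()
    raise TypeError(f"no array namespace registered for {typ!r}")


def load_as(val: Any, typ: type) -> Any:
    """Hypothetical loader: convert `val` to an array from the annotated type's library."""
    return namespace_for_type(typ).asarray(val)


# --- Toy stand-ins so the sketch runs without JAX or NumPy installed ---
class ToyArray:
    def __init__(self, data: Any) -> None:
        self.data = list(data)

    def __array_namespace__(self, /, *, api_version: Any = None) -> Any:
        # Instance method, as in the standard: it needs an array instance.
        return toy_namespace


toy_namespace = types.SimpleNamespace(asarray=ToyArray)


class SubToyArray(ToyArray):
    """Subclass, to show that the MRO walk finds the registered base class."""


register_namespace(ToyArray, lambda: toy_namespace)

# Calling the method on the type fails: there is no `self` to bind.
try:
    ToyArray.__array_namespace__()
except TypeError as exc:
    print("type-level call fails:", exc)

arr = load_as([1, 2, 3], SubToyArray)
print(type(arr).__name__, arr.data)  # prints: ToyArray [1, 2, 3]
```

For a real library, the registered factory could derive the namespace from a tiny sample instance, e.g. `register_namespace(jax.Array, lambda: jnp.zeros(0).__array_namespace__())`, assuming the installed JAX version implements `__array_namespace__` on its arrays. The trade-off is that every supported array type has to be registered up front, since the standard only specifies the namespace lookup on array instances.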