How to discover an array type's namespace from an input type annotation #948

nicholasjng asked this question in Q&A (unanswered)

In machine learning workflow engines, I've come across a lot of code like this (using JAX as the array API implementation in this example):

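```python
import jax
import jax.numpy as jnp

# `task` and `workflow` come from the workflow library; these no-op
# stand-ins (hypothetical) only make the snippet runnable on its own.
def task(fn):
    return fn

workflow = task


@task
def zeros() -> jax.Array:
    return jnp.zeros(5)


@task
def compute_sum(arr: jax.Array) -> float:
    return jnp.sum(arr).item()


@workflow
def basic_workflow():
    arr = zeros()
    compute_sum(arr)


if __name__ == "__main__":
    basic_workflow()
```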

where the array I/O needed for the output of `zeros` and the input of `compute_sum` is abstracted away from the user by the `@task` decorator.

Many workflow libraries let the user customize this I/O, for example through an interface like the following:

```python
from typing import Any

import jax


class MyJAXArrayIO:
    def load(self, val: Any, typ: type[jax.Array]) -> jax.Array:
        """Instantiate a loaded value as a JAX array."""
        ...

    def store(self, val: jax.Array) -> None:
        """Write a JAX array, e.g. to disk."""
        ...
```

whose `store` and `load` methods would be called on the arrays that come out of `zeros` and go into `compute_sum`, respectively. In this example, these are the same array.
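To make the mechanism concrete, here is a minimal sketch of how a workflow engine might route annotated arguments and return values through such an I/O class. The decorator factory `task_with_io` and the `IdentityIO` class are hypothetical and not taken from any particular library; a real engine would serialize to disk or object storage and pick the handler based on the annotation.

```python
import functools
import inspect
import typing
from typing import Any, Callable


class IdentityIO:
    """Hypothetical I/O handler: keeps stored values in memory, loads values as-is."""

    def __init__(self) -> None:
        self.stored: list[Any] = []

    def load(self, val: Any, typ: type) -> Any:
        # A real handler would convert `val` into an instance of `typ`.
        return val

    def store(self, val: Any) -> None:
        self.stored.append(val)


def task_with_io(io: IdentityIO) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator factory: load inputs and store outputs via `io`."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        hints = typing.get_type_hints(fn)
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bound = sig.bind(*args, **kwargs)
            # Pass every annotated argument through `io.load`, using its
            # annotation as the target type; this is exactly where the
            # annotation (a type, not an instance) is all that is available.
            for name, value in bound.arguments.items():
                if name in hints:
                    bound.arguments[name] = io.load(value, hints[name])
            result = fn(*bound.args, **bound.kwargs)
            io.store(result)
            return result

        return wrapper

    return decorator
```

For example, decorating `def double(x: int) -> int` with `@task_with_io(io)` and calling `double(21)` returns 42 and leaves `io.stored == [42]`. The key design point is that `load` receives only the annotation, which is why the question below is about recovering a namespace from a type.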

I think the array API standard would be a good fit for generalizing such I/O machinery to any type that implements the standard, since it could provide a very basic load mechanism like this:

```python
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class MyArrayStandardIO(Generic[T]):
    def load(self, val: Any, typ: type[T]) -> T:
        """Instantiate a loaded value as an array-API-conforming array."""
        # Step 1: look up the type's array API namespace...
        namespace = typ.__array_namespace__()
        # ... and return the value as an array in that namespace.
        return namespace.asarray(val)
```

Unfortunately, `__array_namespace__` appears to be an instance method only, so I cannot call it on the array type itself. With JAX, for example, I get the following:

```
>>> import jax
>>> jax.Array.__array_namespace__()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/nicholasjunge/.../site-packages/jax/_src/numpy/array_methods.py", line 1133, in method
    raise NotImplementedError(f"Cannot call abstract method {name}")
NotImplementedError: Cannot call abstract method __array_namespace__
```
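The failure is not specific to JAX: the standard defines `__array_namespace__` as a method on array objects, so looking it up on the class yields a function that still expects an instance. The stdlib-only sketch below reproduces this with a hypothetical `ToyArray` class and `toy_namespace` object; neither corresponds to a real library.

```python
import types

# Hypothetical stand-in for an array API namespace such as `jax.numpy`.
toy_namespace = types.SimpleNamespace(asarray=lambda val: ToyArray(val))


class ToyArray:
    """Hypothetical array type implementing only `__array_namespace__`."""

    def __init__(self, data):
        self.data = data

    def __array_namespace__(self, /, *, api_version=None):
        # Instance method, as in the array API standard: it needs `self`.
        return toy_namespace


def namespace_from_instance(x):
    # Works: `x` is an array instance, so `self` is bound automatically.
    return x.__array_namespace__()


def namespace_from_type(typ):
    # Fails: on the class, `__array_namespace__` is a plain function, so
    # calling it without an instance raises TypeError (missing `self`).
    return typ.__array_namespace__()
```

Here `namespace_from_instance(ToyArray([1, 2]))` returns `toy_namespace`, while `namespace_from_type(ToyArray)` raises `TypeError`. In other words, the standard's hook is only reachable once you already hold an array.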

NumPy raises a different error, but I cannot obtain the namespace there either. As an added difficulty, JAX's array namespace is `jax.numpy`, but `jax.Array.__module__ == 'jax'`, so the namespace cannot simply be inferred from the type's module.
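Absent a standard hook, one non-standard fallback is an explicit registry that maps array types to namespaces and resolves subclasses through the MRO, which avoids relying on `__module__` altogether. This is an assumption, not something the array API standard provides; `register_namespace`, `namespace_for_type`, and the toy types are hypothetical, and a real registry would hold entries such as `jax.Array` mapped to `jax.numpy`.

```python
import types
from typing import Any

# Hypothetical registry mapping an array type to its array API namespace.
_NAMESPACES: dict[type, Any] = {}


def register_namespace(array_type: type, namespace: Any) -> None:
    """Record the namespace to use for `array_type` and its subclasses."""
    _NAMESPACES[array_type] = namespace


def namespace_for_type(typ: type) -> Any:
    """Return the registered namespace for `typ`, checking base classes too."""
    for base in typ.__mro__:
        if base in _NAMESPACES:
            return _NAMESPACES[base]
    raise TypeError(f"no array API namespace registered for {typ!r}")


# Toy stand-ins so the sketch runs without any array library installed.
class ToyArray:
    def __init__(self, data):
        self.data = list(data)


class ToySubArray(ToyArray):
    pass


toy_namespace = types.SimpleNamespace(asarray=ToyArray)
register_namespace(ToyArray, toy_namespace)
```

With this, `namespace_for_type(ToySubArray).asarray([1, 2])` builds a `ToyArray`, and an unregistered type such as `int` raises `TypeError`. The cost is that every supported library must be registered by hand, which is the kind of bookkeeping a standards-level answer would remove.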

Is there any standards-conforming way to deduce the array namespace from an array type annotation alone?
