jax and bfloat16 after #2874 #3255

johnnewsom started this conversation in General

Hi - as of #2874 / v3.1.0, is there a canonical way to implement/support bfloat16? I found this proof-of-concept for adding new dtypes, but it looks like the interface has changed a bit since that proposal. Curious if anyone's thought about this yet. Schemed this up by more-or-less copying the example from the docs:

```python
"""Zarr-compatible bf16"""

from typing import ClassVar, Literal, Self, TypeGuard, overload

import ml_dtypes
import numpy as np
from zarr.core.common import JSON, ZarrFormat
from zarr.core.dtype import ZDType, data_type_registry
from zarr.core.dtype.common import (
    DataTypeValidationError,
    DTypeConfig_V2,
    DTypeJSON,
    check_dtype_spec_v2,
)

np.sctypeDict["bfloat16"] = ml_dtypes.bfloat16
bf16_dtype_cls = type(np.dtype("bfloat16"))
bf16_scalar_cls = ml_dtypes.bfloat16


class Bf16(ZDType[bf16_dtype_cls, bf16_scalar_cls]):
    """Zarr-wrapper around bfloat16"""

    # This field is used as the key for the data type in the internal data type registry, and also
    # as the identifier for the data type when serializing the data type to disk for zarr v3
    _zarr_v3_name: ClassVar[Literal["bfloat16"]] = "bfloat16"
    # this field will be used internally
    _zarr_v2_name: ClassVar[Literal["bfloat16"]] = "bfloat16"

    # we bind a class variable to the native data type class so we can create instances of it
    dtype_cls = bf16_dtype_cls

    @classmethod
    def from_native_dtype(cls, dtype: np.dtype) -> Self:
        """Create an instance of this ZDType from a native dtype."""
        if cls._check_native_dtype(dtype):
            return cls()
        raise DataTypeValidationError(
            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
        )

    def to_native_dtype(self: Self) -> bf16_dtype_cls:
        """Create a bfloat16 dtype instance from this ZDType"""
        return self.dtype_cls()

    @classmethod
    def _check_json_v2(cls, data: DTypeJSON):
        """
        Type check for Zarr v2-flavored JSON.

        This will check that the input is a dict like this:

        .. code-block:: json

        {
            "name": "bfloat16",
            "object_codec_id": None
        }

        Note that this representation differs from what the ``dtype`` field looks like in zarr v2 metadata.
        Specifically, whatever goes into the ``dtype`` field in metadata is assigned to the ``name`` field here.

        See the Zarr docs for more information about the JSON encoding for data types.
        """
        return (
            check_dtype_spec_v2(data)
            and data["name"] == "bfloat16"
            and data["object_codec_id"] is None
        )

    @classmethod
    def _check_json_v3(cls, data: DTypeJSON):
        """
        Type check for Zarr V3-flavored JSON.

        Checks that the input is the string "bfloat16".
        """
        return data == cls._zarr_v3_name

    @classmethod
    def _from_json_v2(cls, data: DTypeJSON) -> Self:
        """
        Create an instance of this ZDType from Zarr V2-flavored JSON.

        This first does a type check on the input, and if that passes we create an instance of the ZDType.
        """
        if cls._check_json_v2(data):
            return cls()
        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_name!r}"
        raise DataTypeValidationError(msg)

    @classmethod
    def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self:
        """
        Create an instance of this ZDType from Zarr V3-flavored JSON.

        This first does a type check on the input, and if that passes we create an instance of the ZDType.
        """
        if cls._check_json_v3(data):
            return cls()
        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
        raise DataTypeValidationError(msg)

    @overload  # type: ignore[override]
    def to_json(
        self, zarr_format: Literal[2]
    ) -> DTypeConfig_V2[Literal["bfloat16"], None]: ...

    @overload
    def to_json(self, zarr_format: Literal[3]) -> Literal["bfloat16"]: ...

    def to_json(
        self, zarr_format: ZarrFormat
    ) -> DTypeConfig_V2[Literal["bfloat16"], None] | Literal["bfloat16"]:
        """
        Serialize this ZDType to v2- or v3-flavored JSON

        If the zarr_format is 2, then return a dict like this:

        .. code-block:: json

        {
            "name": "bfloat16",
            "object_codec_id": None
        }

        If the zarr_format is 3, then return the string "bfloat16"
        """
        if zarr_format == 2:
            return {"name": "bfloat16", "object_codec_id": None}
        elif zarr_format == 3:
            return self._zarr_v3_name
        raise ValueError(
            f"zarr_format must be 2 or 3, got {zarr_format}"
        )  # pragma: no cover

    def _check_scalar(self, data: object) -> TypeGuard[float | ml_dtypes.bfloat16]:
        """
        Check if a python object is a valid bfloat16-compatible scalar
        """
        return isinstance(data, (float, bf16_scalar_cls))

    def cast_scalar(self, data: object) -> ml_dtypes.bfloat16:
        """
        Attempt to cast a python object to a bfloat16.

        We first perform a type check to ensure that the input type is appropriate, and if that
        passes we call the bfloat16 scalar constructor.
        """
        if self._check_scalar(data):
            return ml_dtypes.bfloat16(data)
        msg = (
            f"Cannot convert object {data!r} with type {type(data)} to a scalar compatible with the "
            f"data type {self}."
        )
        raise TypeError(msg)

    def default_scalar(self) -> ml_dtypes.bfloat16:
        """
        Get the default scalar value. This will be used when automatically selecting a fill value.
        """
        return ml_dtypes.bfloat16(0)

    def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> float:
        """
        Convert a python object to a JSON representation of a float scalar.
        This is necessary for taking user input for the ``fill_value`` attribute in array metadata.

        In this implementation, we optimistically convert the input to a float.
        """
        # We could add a type check here, but we don't need to for this example
        return float(data)

    def from_json_scalar(
        self, data: JSON, *, zarr_format: ZarrFormat
    ) -> ml_dtypes.bfloat16:
        """
        Read a JSON-serializable value as a bfloat16 scalar.

        We first perform a type check to ensure that the JSON value is well-formed, then call the
        bfloat16 scalar constructor.

        The base definition of this method requires that it take a zarr_format parameter because
        other data types serialize scalars differently in zarr v2 and v3, but we don't use this here.
        """
        if self._check_scalar(data):
            return ml_dtypes.bfloat16(data)
        raise TypeError(f"Invalid type: {data}. Expected a float.")


data_type_registry.register(Bf16._zarr_v3_name, Bf16)
```

So far it seems to be working okay, but I would love to get some input in case I'm doing something dumb!
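One detail that may be worth checking in `to_json_scalar` and `from_json_scalar`: returning `float(data)` works for finite fill values, but NaN and the infinities are not valid JSON numbers. As far as I understand the Zarr v3 spec, those are written as the strings `"NaN"`, `"Infinity"`, and `"-Infinity"` in `fill_value`. Below is a minimal, stdlib-only sketch of that mapping. The helper names `encode_float_fill_value` and `decode_float_fill_value` are hypothetical and not part of zarr-python, and the sketch ignores the hex-string form the spec also allows.

```python
from __future__ import annotations

import json
import math


def encode_float_fill_value(value: float) -> float | str:
    """Encode a float fill value as strict JSON (hypothetical helper)."""
    value = float(value)
    if math.isnan(value):
        return "NaN"
    if math.isinf(value):
        return "Infinity" if value > 0 else "-Infinity"
    return value


def decode_float_fill_value(data: float | int | str) -> float:
    """Inverse of encode_float_fill_value (hypothetical helper)."""
    special = {"NaN": math.nan, "Infinity": math.inf, "-Infinity": -math.inf}
    if isinstance(data, str):
        if data not in special:
            raise ValueError(f"Unrecognized float string: {data!r}")
        return special[data]
    # bool is a subclass of int, so reject it explicitly
    if isinstance(data, bool) or not isinstance(data, (int, float)):
        raise TypeError(f"Invalid type: {type(data)}. Expected a number or string.")
    return float(data)


# The encoded values survive strict JSON serialization (allow_nan=False).
encoded = [encode_float_fill_value(v) for v in (0.0, math.nan, -math.inf)]
strict_json = json.dumps(encoded, allow_nan=False)
```

Inside the `Bf16` class, the same idea would mean having `to_json_scalar` return a string for the special values and having `from_json_scalar` accept those strings, instead of relying on `float(data)` and the `isinstance(data, float)` check alone.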

Replies: 1 comment

@jacknewsom if you copied this example, and it's working, then that's great! Let us know if you run into any problems storing data.

If you plan to share your data with other people, you should probably ensure that you use a common JSON representation for bfloat16. According to this spec document, the JSON form should just be the string "bfloat16", which is exactly what you are using.
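To make the point about the shared JSON form concrete, here is a small stdlib-only sketch that checks whether a Zarr v3 array metadata document (`zarr.json`) declares `"data_type": "bfloat16"`. The function name `declares_bfloat16` is hypothetical, and the example metadata (shape, chunking, codecs) is made up for illustration rather than taken from a real array.

```python
import json


def declares_bfloat16(zarr_json_text: str) -> bool:
    """Return True if Zarr v3 array metadata declares data_type "bfloat16"."""
    meta = json.loads(zarr_json_text)
    return (
        meta.get("zarr_format") == 3
        and meta.get("node_type") == "array"
        and meta.get("data_type") == "bfloat16"
    )


# Illustrative metadata only; the values below are assumptions, not real output.
example_metadata = json.dumps(
    {
        "zarr_format": 3,
        "node_type": "array",
        "shape": [4],
        "data_type": "bfloat16",
        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [4]}},
        "chunk_key_encoding": {"name": "default", "configuration": {"separator": "/"}},
        "fill_value": 0.0,
        "codecs": [{"name": "bytes", "configuration": {"endian": "little"}}],
    }
)

print(declares_bfloat16(example_metadata))
```

Running a check like this against the `zarr.json` that your array actually writes is a quick way to confirm the on-disk name before sharing the data with another implementation.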

@jbms does tensorstore support zarr + bfloat16? If so, compatibility with tensorstore would be a good external test case.

And if this data type is popular enough, we could consider defining it in zarr python or an external package under an optional dependency.

2 participants
@johnnewsom @d-v-b
