jax and bfloat16 after #2874 #3255
Hi - as of #2874 / v3.1.0, is there a canonical way to implement/support bfloat16? I found this proof-of-concept for adding new dtypes, but it looks like the interface has changed a bit since that proposal. Curious if anyone's thought about this yet.

I schemed this up by more or less copying the example from the docs:

```python
"""Zarr-compatible bf16"""

from typing import ClassVar, Literal, Self, TypeGuard, overload

import ml_dtypes
import numpy as np

from zarr.core.common import JSON, ZarrFormat
from zarr.core.dtype import ZDType, data_type_registry
from zarr.core.dtype.common import (
    DataTypeValidationError,
    DTypeConfig_V2,
    DTypeJSON,
    check_dtype_spec_v2,
)

# Register the name so that the string "bfloat16" can be passed to np.dtype().
np.sctypeDict["bfloat16"] = ml_dtypes.bfloat16
bf16_dtype_cls = type(np.dtype("bfloat16"))
bf16_scalar_cls = ml_dtypes.bfloat16


class Bf16(ZDType[bf16_dtype_cls, bf16_scalar_cls]):
    """Zarr wrapper around bfloat16"""

    # This field is used as the key for the data type in the internal data type registry,
    # and as the identifier for the data type when serializing it to disk for Zarr v3.
    _zarr_v3_name: ClassVar[Literal["bfloat16"]] = "bfloat16"

    # This field will be used internally.
    _zarr_v2_name: ClassVar[Literal["bfloat16"]] = "bfloat16"

    # We bind a class variable to the native data type class so we can create instances of it.
    dtype_cls = bf16_dtype_cls

    @classmethod
    def from_native_dtype(cls, dtype: np.dtype) -> Self:
        """Create an instance of this ZDType from a native dtype."""
        if cls._check_native_dtype(dtype):
            return cls()
        raise DataTypeValidationError(
            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
        )

    def to_native_dtype(self: Self) -> bf16_dtype_cls:
        """Create a bfloat16 dtype instance from this ZDType."""
        return self.dtype_cls()

    @classmethod
    def _check_json_v2(cls, data: DTypeJSON):
        """
        Type check for Zarr v2-flavored JSON.

        This will check that the input is a dict like this:

        .. code-block:: json

            {
                "name": "bfloat16",
                "object_codec_id": null
            }

        Note that this representation differs from what the ``dtype`` field looks like in
        Zarr v2 metadata. Specifically, whatever goes into the ``dtype`` field in metadata
        is assigned to the ``name`` field here.

        See the Zarr docs for more information about the JSON encoding for data types.
        """
        return (
            check_dtype_spec_v2(data)
            and data["name"] == "bfloat16"
            and data["object_codec_id"] is None
        )

    @classmethod
    def _check_json_v3(cls, data: DTypeJSON):
        """
        Type check for Zarr v3-flavored JSON.

        Checks that the input is the string "bfloat16".
        """
        return data == cls._zarr_v3_name

    @classmethod
    def _from_json_v2(cls, data: DTypeJSON) -> Self:
        """
        Create an instance of this ZDType from Zarr v2-flavored JSON.

        This first does a type check on the input, and if that passes we create an
        instance of the ZDType.
        """
        if cls._check_json_v2(data):
            return cls()
        msg = (
            f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected "
            f"{{'name': {cls._zarr_v2_name!r}, 'object_codec_id': None}}"
        )
        raise DataTypeValidationError(msg)

    @classmethod
    def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self:
        """
        Create an instance of this ZDType from Zarr v3-flavored JSON.

        This first does a type check on the input, and if that passes we create an
        instance of the ZDType.
        """
        if cls._check_json_v3(data):
            return cls()
        msg = (
            f"Invalid JSON representation of {cls.__name__}. Got {data!r}, "
            f"expected the string {cls._zarr_v3_name!r}"
        )
        raise DataTypeValidationError(msg)

    @overload  # type: ignore[override]
    def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["bfloat16"], None]: ...

    @overload
    def to_json(self, zarr_format: Literal[3]) -> Literal["bfloat16"]: ...

    def to_json(
        self, zarr_format: ZarrFormat
    ) -> DTypeConfig_V2[Literal["bfloat16"], None] | Literal["bfloat16"]:
        """
        Serialize this ZDType to v2- or v3-flavored JSON.

        If zarr_format is 2, return a dict like this:

        .. code-block:: json

            {
                "name": "bfloat16",
                "object_codec_id": null
            }

        If zarr_format is 3, return the string "bfloat16".
        """
        if zarr_format == 2:
            return {"name": "bfloat16", "object_codec_id": None}
        elif zarr_format == 3:
            return self._zarr_v3_name
        raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover

    def _check_scalar(self, data: object) -> TypeGuard[float | ml_dtypes.bfloat16]:
        """
        Check if a Python object is a valid bfloat16-compatible scalar.
        """
        return isinstance(data, (float, bf16_scalar_cls))

    def cast_scalar(self, data: object) -> ml_dtypes.bfloat16:
        """
        Attempt to cast a Python object to a bfloat16.

        We first perform a type check to ensure that the input type is appropriate, and
        if that passes we call the bfloat16 scalar constructor.
        """
        if self._check_scalar(data):
            return ml_dtypes.bfloat16(data)
        msg = (
            f"Cannot convert object {data!r} with type {type(data)} to a scalar compatible "
            f"with the data type {self}."
        )
        raise TypeError(msg)

    def default_scalar(self) -> ml_dtypes.bfloat16:
        """
        Get the default scalar value. This will be used when automatically selecting a
        fill value.
        """
        return ml_dtypes.bfloat16(0)

    def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> float:
        """
        Convert a Python object to a JSON representation of a float scalar.

        This is necessary for taking user input for the ``fill_value`` attribute in array
        metadata. In this implementation, we optimistically convert the input to a
        Python float.
        """
        # We could add a type check here, but we don't need to for this example.
        return float(data)

    def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> ml_dtypes.bfloat16:
        """
        Read a JSON-serializable value as a bfloat16 scalar.

        We first perform a type check to ensure that the JSON value is well-formed, then
        call the bfloat16 scalar constructor.

        The base definition of this method requires that it take a zarr_format parameter
        because other data types serialize scalars differently in Zarr v2 and v3, but we
        don't use it here.
        """
        if self._check_scalar(data):
            return ml_dtypes.bfloat16(data)
        raise TypeError(f"Invalid type: {data}. Expected a float.")


data_type_registry.register(Bf16._zarr_v3_name, Bf16)
```

So far it seems to be working okay, but I would love to get some input in case I'm doing something dumb!
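One detail that may be worth checking in the code above is fill-value serialization. `to_json_scalar` returns `float(data)`, and Python's `json.dumps(float("nan"))` emits the bare token `NaN`, which is not valid JSON. Separately, `from_json_scalar` rejects a JSON integer such as `0` (which another tool might write) because `_check_scalar` only accepts `float` and bfloat16 scalars. Below is a minimal stdlib-only sketch of a stricter encoding. It assumes bfloat16 fill values follow the same convention as the Zarr v3 core floating-point types (non-finite values written as the strings `"NaN"`, `"Infinity"` and `"-Infinity"`); the helper names are hypothetical and not part of zarr-python.

```python
from __future__ import annotations

import json
import math


def bf16_fill_value_to_json(value: float) -> float | str:
    """Hypothetical helper: encode a fill value so json.dumps emits valid JSON.

    Assumes non-finite values are spelled as in the Zarr v3 core float types.
    """
    value = float(value)
    if math.isnan(value):
        return "NaN"
    if math.isinf(value):
        return "Infinity" if value > 0 else "-Infinity"
    return value


def bf16_fill_value_from_json(data: object) -> float:
    """Hypothetical helper: inverse of bf16_fill_value_to_json.

    Hex-string fill values are not handled in this sketch.
    """
    if isinstance(data, str):
        specials = {"NaN": math.nan, "Infinity": math.inf, "-Infinity": -math.inf}
        if data in specials:
            return specials[data]
        raise ValueError(f"Unsupported string fill value: {data!r}")
    # Accept JSON integers as well as floats; bool is a subclass of int, so reject it.
    if isinstance(data, (int, float)) and not isinstance(data, bool):
        return float(data)
    raise TypeError(f"Invalid fill value: {data!r}")


print(json.dumps({"fill_value": bf16_fill_value_to_json(float("nan"))}))
```

Inside `Bf16`, `to_json_scalar` could return `bf16_fill_value_to_json(float(data))`, and `from_json_scalar` could wrap `bf16_fill_value_from_json(data)` in `ml_dtypes.bfloat16(...)`.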
Replies: 1 comment
@jacknewsom if you copied this example and it's working, then that's great! Let us know if you run into any problems storing data.

If you plan to share your data with other people, you should probably ensure that you use a common JSON representation for bfloat16. According to this spec document, the JSON form should just be the string

@jbms, does tensorstore support Zarr + bfloat16? If so, compatibility with tensorstore would be a good external test case. And if this data type is popular enough, we could consider defining it in zarr-python or in an external package under an optional dependency.
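For a tensorstore cross-check, a byte-level reference that depends on neither library could help. bfloat16 is the upper 16 bits of an IEEE 754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), so encoding can be sketched with only `struct`. This is a sketch under assumptions: the helper names are hypothetical, narrowing uses round-to-nearest-even (verify that this matches what ml_dtypes and tensorstore actually write), values are rounded through float32 first, and chunk bytes are assumed to be little-endian, as produced by the Zarr v3 `bytes` codec configured with `"endian": "little"`.

```python
import math
import struct


def float_to_bf16_bits(x: float) -> int:
    """Hypothetical helper: round x to bfloat16 and return its 16-bit pattern.

    Rounds via float32 first; struct raises OverflowError if |x| exceeds
    the float32 range.
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    (bits32,) = struct.unpack("<I", struct.pack("<f", x))
    if math.isnan(x):
        # Keep sign and exponent, and force the quiet-NaN mantissa bit.
        return (bits32 >> 16) | 0x0040
    # Round to nearest, ties to even, on the 16 low bits that get dropped.
    lsb = (bits32 >> 16) & 1
    bits32 += 0x7FFF + lsb
    return (bits32 >> 16) & 0xFFFF


def bf16_bits_to_float(bits16: int) -> float:
    """Hypothetical helper: widen a bfloat16 bit pattern to a float (exact)."""
    (value,) = struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))
    return value


def encode_bf16_le(values: list[float]) -> bytes:
    """Hypothetical helper: pack values as little-endian bfloat16, 2 bytes each."""
    return b"".join(struct.pack("<H", float_to_bf16_bits(v)) for v in values)


print(bf16_bits_to_float(float_to_bf16_bits(math.pi)))
print(encode_bf16_le([1.0, -2.0]).hex())
```

Comparing `encode_bf16_le(...)` against the raw, uncompressed chunk bytes written by zarr-python and by tensorstore for the same input would show whether the two agree on byte layout and rounding.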