Add ScalarType -> shim conversion, add stable::Tensor.scalar_type #160557
Conversation
[ghstack-poisoned]
pytorch-bot commented Aug 13, 2025 (edited)
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160557
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f12f9d1 with merge base a44a0d3.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack[0] = from(t);
stack[1] = from(std::optional(t_dtype)); // dtype
stack[1] = from(std::optional(t.scalar_type())); // dtype
For testing
The to/from logic of this file got moved to utils.h
This file got split into tensor-struct.h, which has everything before the PR, and tensor-inl.h, which implements scalar_type as it relies on from/to. The reason for the split is to allow the code to build without circular dependencies. Without the split, tensor.h would depend on library.h (for to/from) and library.h would depend on tensor.h (because to/from Tensor needs a Tensor def).
Now, utils.h (which has to/from) depends on tensor-struct.h, tensor-inl.h depends on both utils.h and tensor-struct.h, and users still depend on tensor.h, which depends on all of the above.
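The layering described in this comment can be sketched as an include graph (file roles paraphrased from the comment above; the library.h edge is my assumption):

```text
tensor-struct.h : Tensor definition only; no from/to
utils.h         : to/from (StableIValue) conversions; includes tensor-struct.h
tensor-inl.h    : Tensor::scalar_type() impl; includes utils.h + tensor-struct.h
tensor.h        : umbrella header for users; includes tensor-struct.h + tensor-inl.h
library.h       : presumably includes utils.h for to/from; no cycle back to tensor.h
```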
case ScalarType::UInt64:
  return from(aoti_torch_dtype_uint64());
default:
  throw std::runtime_error(
NOTE!! THIS IS WHERE I WANT REVIEW!! cc @albanD
Prior, if we had an IValue dtype that was qint8, from_ivalue would call from(ScalarType::Qint8), and the code would just reinterpret the enum and spit out the corresponding int32_t. This was okay because ScalarType wasn't exposed to the end user, and all they had to work with was an abstracted int32_t that they would get from the C shim.
However, with this change today, from(ScalarType::Qint8) would error!!!! Because now ScalarType is allowed to be used by the end user, they can call this function, and naively reinterpreting the enum is no longer OK if the extension's ScalarType is different from libtorch's ScalarType! I think erroring is acceptable because these other types are infrequently used by people anyway, but maybe I am wrong about that. e.g., @swolchok, are the Bits ScalarTypes used in ET?
Answering your specific question:
I haven't (yet?) made any attempt to use PyTorch's ScalarType in ExecuTorch. ExecuTorch has https://github.com/pytorch/executorch/blob/52b45e2d2ac244b13a36ddf5d21a9ebe8d8aa17e/runtime/core/portable_type/scalar_type.h#L132
PyTorch's ScalarType will get used in ExecuTorch's ATen mode, though. https://github.com/pytorch/executorch/blob/52b45e2d2ac244b13a36ddf5d21a9ebe8d8aa17e/runtime/core/exec_aten/exec_aten.h#L82
I don't know what the Bits ScalarTypes even are, but ExecuTorch seems to have its own versions of them that it uses: https://github.com/search?q=repo%3Apytorch%2Fexecutorch+ScalarType%3A%3ABits+language%3AC%2B%2B&type=code&l=C%2B%2B
In general: it is not backward compatible to change functionality such that a call that previously succeeded (and really did work fine) is now an error.
I've concluded it is okay to BC break here given that GitHub search yields 0 users for the narrow use case for which this code would break. Updated the PR body accordingly.
this code is a copy pasta EXCEPT for the specializations for torch::headeronly::ScalarType
…ar_type" This change _modifies_ the from/to behavior between ScalarType and StableValue! Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t). I then changed the test to test the scalar_type API. This code change required some refactoring because of circular dependencies. [ghstack-poisoned]
@@ -0,0 +1,342 @@
#pragma once
I firmly dislike naming things "utils" because it is synonymous with "stuff" and helps neither predict their current contents nor limit their future contents. Instead I would consider a specific name, like say StableIValueConversions.h.
renamed!!!!!! stableivalue_conversions.h
struct FromImpl<ScalarType> {
  static StableIValue call(ScalarType val) {
    switch (val) {
      case ScalarType::Byte:
This feels like it could benefit from a #define of sorts that iterates over known dtypes (if dtypes are part of the stable API...)
I was trying to figure out a way, but it did not seem worth it for this PR. (Also there's no "is dtype part of stable API" function we can call yet).
auto inner_val = to<T>(*sivp);
// free the memory associated with StableIValue* sivp
delete sivp;
This feels very suspicious... Why not pass the value as a unique_ptr?
I'm not sure about passing it around as std::unique_ptr, but you could certainly declare std::unique_ptr<StableIValue> sivp = to<StableIValue*>(val); above and insulate yourself against the inner to<T> throwing an exception.
will address in a followup
@@ -0,0 +1,342 @@
#pragma once
nit: is "stableivalue_conversions.h" really the conventional formatting here? I would've expected "StableIValueConversions.h", as with torch/headeronly/core/ScalarType.h
I think in general it would be good to document the header format, as I've seen both CamelCase.h, snake_case.h, and something-weird.h (for example tensor-inl.h in this PR).
For example, aoti follows snake_case.
idk the conventional format, but I was going under the guise of using Caps when there was a struct definition of the same name, and using lowercase for more thematic naming (like everything in this file relates to ____). And the -dashes I'm copying from c10/'s Half and Half-inl.h, where the dash means it's a continuation of a file: these files would be together if possible but are broken apart for some other reason.
I'm happy to follow an existing header notation though, if there is one.
  return from(aoti_torch_dtype_uint8());
case ScalarType::Char:
  return from(aoti_torch_dtype_int8());
case ScalarType::Short:
Q: Why not add unsigned types here?
I was following the convention of the actual enum order
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
Is it possible to cast it to some enum and use a switch statement? (That would force devs to add options here when a new dtype is added.) Otherwise this statement will be the result of a never-ending series of "Added missing XYZ" fixes.
Well, the whole reason to have this is that the ScalarType enum baked into the user binary isn't necessarily the same as the libtorch binary's ScalarType, which is why I'm passing ints through the shim.
  return ScalarType::Float8_e5m2fnuz;
} else if (shim_scalartype == aoti_torch_dtype_float8_e4m3fnuz()) {
  return ScalarType::Float8_e4m3fnuz;
} else if (shim_scalartype == aoti_torch_dtype_uint16()) {
Why do you support unsigned dtypes here but not in the previous switch statement?
They're supported above too
@@ -0,0 +1,24 @@
#pragma once
See my above comment. Use either CamelCase or snake_case
okay, I'll switch to snake_case for these
void* data_ptr() const {
  void* data_ptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
  return data_ptr;
}
int64_t dim() const {
  int64_t dim;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
  return dim;
}
int64_t numel() const {
  int64_t numel;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
  return numel;
}
Please please use macros to get rid of copy_pasta
Suggested change (replacing the three hand-written accessors above):

#define _DEF_CONST_ACCESSOR_METHOD(NAME, DTYPE)                     \
  DTYPE NAME() const {                                              \
    DTYPE rc;                                                       \
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_##NAME(ath_.get(), &rc)); \
    return rc;                                                      \
  }

_DEF_CONST_ACCESSOR_METHOD(data_ptr, void*)
_DEF_CONST_ACCESSOR_METHOD(dim, int64_t)
_DEF_CONST_ACCESSOR_METHOD(numel, int64_t)
#undef _DEF_CONST_ACCESSOR_METHOD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
It is more readable currently to see these APIs as is. If we need to add more, I will consider using a preprocessor macro.
…ar_type"
TL;DR: Moving to ScalarType in user extensions and removing deprecated dtypes.
This change _modifies_ the from/to behavior between ScalarType and StableValue! Whereas before, user extensions could only pass around obfuscated dtypes appearing as int32_ts, now users can confidently use torch::headeronly::ScalarType in their extensions for major scalar types. This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear.
Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t).
I then changed the test to test the scalar_type API.
This code change required some refactoring because of circular dependencies.
## BC Breaking note
This commit is (narrowly) BC-breaking for unpopular dtypes: `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2`, in the narrow use case where an extension retrieves a Tensor dtype of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change significantly justify BC-breaking this API.
[ghstack-poisoned]
janeyx99 commented Aug 19, 2025
@pytorchbot merge
pytorchmergebot commented Aug 19, 2025
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #159508
Approved by: https://github.com/janeyx99
ghstack dependencies: #160557
…torch#160557)
Pull Request resolved: pytorch#160557
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
…9508)
Pull Request resolved: pytorch#159508
Approved by: https://github.com/janeyx99
ghstack dependencies: pytorch#160557
TL;DR: Moving to ScalarType in user extensions and removing deprecated dtypes.
This change _modifies_ the from/to behavior between ScalarType and StableValue! Whereas before, user extensions could only pass around obfuscated dtypes appearing as int32_ts, now users can confidently use torch::headeronly::ScalarType in their extensions for major scalar types. This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear.
Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t).
I then changed the test to test the scalar_type API.
This code change required some refactoring because of circular dependencies.
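For extension authors, the end state might look like the following sketch (header path, namespaces, and the float check are assumptions based on this PR's description, not verified against the final tree):

```cpp
// Hypothetical extension code; torch::stable::Tensor and
// torch::headeronly::ScalarType are the types this PR stack introduces.
#include <torch/csrc/stable/tensor.h>

void my_op(const torch::stable::Tensor& t) {
  // scalar_type() now returns a real ScalarType instead of an int32_t.
  if (t.scalar_type() == torch::headeronly::ScalarType::Float) {
    float* data = static_cast<float*>(t.data_ptr());
    // ... operate on float data ...
    (void)data;
  }
}
```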
BC Breaking note
This commit is (narrowly) BC-breaking for unpopular dtypes: `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2`, in the narrow use case where an extension retrieves a Tensor dtype of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change significantly justify BC-breaking this API.
Stack from ghstack (oldest at bottom):