|
4 | 4 | // code for better UX. |
5 | 5 |
|
6 | 6 | #include<torch/csrc/inductor/aoti_torch/c/shim.h> |
7 | | -#include<torch/csrc/stable/tensor.h> |
8 | 7 |
|
9 | | -#include<optional> |
| 8 | +// Technically, this file doesn't use anything from utils.h, but we |
| 9 | +// need to include it here as the contents of utils.h used to live |
| 10 | +// here and so we need to expose them for backwards compatibility. |
| 11 | +#include<torch/csrc/stable/utils.h> |
10 | 12 |
|
11 | 13 | // use anonymous namespace to avoid collisions between differing |
12 | 14 | // versions of this file that may be included by different sources |
13 | 15 | namespace { |
14 | 16 |
|
15 | | -// ============================================================================= |
16 | | -// helpers for converting between StableIValue and T |
17 | | -// ============================================================================= |
18 | | - |
19 | | -// forward declare so that from/to() calls in detail work |
20 | | -template<typename T> |
21 | | -StableIValuefrom(T val); |
22 | | -template<typename T> |
23 | | -Tto(StableIValue val); |
24 | | - |
25 | | -namespacedetail { |
26 | | - |
27 | | -// ============================================================================= |
28 | | -// FROM CONVERSIONS (T -> StableIValue) |
29 | | -// ============================================================================= |
30 | | - |
31 | | -// Specialization for general copyable types (catch-all) => StableIValue |
32 | | -template<typename T> |
33 | | -structFromImpl { |
34 | | -static StableIValuecall(T val) { |
35 | | -static_assert( |
36 | | -sizeof(T) <=sizeof(StableIValue), |
37 | | -"StableLibrary stack does not support parameter types larger than 64 bits."); |
38 | | -static_assert(std::is_trivially_copyable_v<T>); |
39 | | -// Initialization should be cheap enough; let's give people well-specified |
40 | | -// reproducible behavior. |
41 | | - StableIValue result =0; |
42 | | -// NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress |
43 | | -// overzealous -Wclass-memaccess. (see |
44 | | -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a |
45 | | -// static_assert above that T is trivially copyable, which should be |
46 | | -// enough. |
47 | | -#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ |
48 | | -std::memcpy(&result,reinterpret_cast<constvoid*>(&val),sizeof(val)); |
49 | | -#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
50 | | -// if value has size less than sizeof(StableIValue), then only lowest bytes |
51 | | -// have to be updated |
52 | | -std::memcpy( |
53 | | -reinterpret_cast<unsignedchar*>(&result) +sizeof(StableIValue) - |
54 | | -sizeof(val), |
55 | | -reinterpret_cast<constvoid*>(&val), |
56 | | -sizeof(val)); |
57 | | -#else |
58 | | -#error Unexpected or undefined __BYTE_ORDER__ |
59 | | -#endif |
60 | | -return result; |
61 | | - } |
62 | | -}; |
63 | | - |
64 | | -// Specialization for std::nullopt_t => StableIValue |
65 | | -template<> |
66 | | -structFromImpl<std::nullopt_t> { |
67 | | -static StableIValuecall(std::nullopt_t val) { |
68 | | -returnfrom(nullptr); |
69 | | - } |
70 | | -}; |
71 | | - |
72 | | -// Specialization for std::optional => StableIValue |
73 | | -// [Handling std::optional] |
74 | | -// When the schema is represented by an optional type, say int?, then we |
75 | | -// expect the custom extension representation to be a std::optional<int> |
76 | | -// (critically NOT int!). In order for all parameters to be stably parsed and |
77 | | -// handled by our dispatcher, we liaison custom extension parameters through |
78 | | -// boxed kernels, meaning that every value will make its way to be an IValue: |
79 | | -// |
80 | | -// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue |
81 | | -// |
82 | | -// When the custom extension value is a literal that can be trivially |
83 | | -// casted to StableIValue, e.g., an int, a float, a pointer, this route is |
84 | | -// ...trivial. The below specialization is for a case when the custom |
85 | | -// extension value would NOT fit within a StableIValue: a std::optional. |
86 | | -// |
87 | | -// If the std::optional has no value, it is treated as std::nullopt, |
88 | | -// whose StableIValue representation is from(nullptr). Otherwise, we: |
89 | | -// 1. unwrap the std::optional<T> |
90 | | -// 2. recursively convert its value of type T to a StableIValue |
91 | | -// 3. allocate heap space for said StableIValue |
92 | | -// 4. convert the resulting StableIValue* into a StableIValue |
93 | | -// |
94 | | -// note that this allocates heap memory! which we expect to be cleaned |
95 | | -// up in the to_ivalue() function defined in shim_common.cpp. We |
96 | | -// purposefully hide this implementation detail from the user so that |
97 | | -// all the user needs to know is: |
98 | | -// |
99 | | -// The schema requests an optional (T?) so I must call `from` on a |
100 | | -// std::optional<T> or a std::nullopt. |
101 | | -template<typename T> |
102 | | -structFromImpl<std::optional<T>> { |
103 | | -static StableIValuecall(const std::optional<T>& val) { |
104 | | -if (!val.has_value()) { |
105 | | -returnfrom(std::nullopt); |
106 | | - } |
107 | | - StableIValue* heap_val =newStableIValue(from(val.value())); |
108 | | -returnfrom(heap_val); |
109 | | - } |
110 | | -}; |
111 | | - |
112 | | -// Specialization for torch::stable::Tensor => StableIValue |
113 | | -// Returns a new owning reference of the underlying Tensor. |
114 | | -template<> |
115 | | -structFromImpl<torch::stable::Tensor> { |
116 | | -static StableIValuecall(const torch::stable::Tensor& val) { |
117 | | - AtenTensorHandle new_ath; |
118 | | -aoti_torch_new_tensor_handle(val.get(), &new_ath); |
119 | | -returnfrom(new_ath); |
120 | | - } |
121 | | -}; |
122 | | - |
123 | | -// ============================================================================= |
124 | | -// TO CONVERSIONS (StableIValue -> T) |
125 | | -// ============================================================================= |
126 | | - |
127 | | -// Specialization for StableIValue => general copyable types (catch-all) |
128 | | -template<typename T> |
129 | | -structToImpl { |
130 | | -static Tcall(StableIValue val) { |
131 | | -static_assert(std::is_trivially_copyable_v<T>); |
132 | | -// T may not have a default constructor. (For example, it might be |
133 | | -// c10::Device.) However, std::memcpy implicitly creates a T at the |
134 | | -// destination. So, we can use a union to work around this lack of |
135 | | -// default constructor. |
136 | | -union Result { |
137 | | -Result() {} |
138 | | - T t; |
139 | | - }; |
140 | | - Result result; |
141 | | -// See NOTE[ -Wclass-memaccess ] above. |
142 | | -#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ |
143 | | -std::memcpy(reinterpret_cast<void*>(&result.t), &val,sizeof(result)); |
144 | | -#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
145 | | -static_assert( |
146 | | -sizeof(T) <=sizeof(StableIValue), |
147 | | -"StableLibrary stack does not support parameter types larger than 64 bits."); |
148 | | -// if value has size less than sizeof(StableIValue), then only lowest bytes |
149 | | -// have to be updated |
150 | | -std::memcpy( |
151 | | -reinterpret_cast<void*>(&result.t), |
152 | | -reinterpret_cast<unsignedchar*>(&val) +sizeof(StableIValue) - |
153 | | -sizeof(result), |
154 | | -sizeof(result)); |
155 | | -#else |
156 | | -#error Unexpected or undefined __BYTE_ORDER__ |
157 | | -#endif |
158 | | -return result.t; |
159 | | - } |
160 | | -}; |
161 | | - |
162 | | -// Specialization for StableIValue => std::nullopt_t |
163 | | -template<> |
164 | | -structToImpl<std::nullopt_t> { |
165 | | -static std::nullopt_tcall(StableIValue val) { |
166 | | -// val should be equivalent to from(nullptr) |
167 | | -return std::nullopt; |
168 | | - } |
169 | | -}; |
170 | | - |
171 | | -// Specialization for StableIValue => std::optional, see [Handling |
172 | | -// std::optional] as the semantic is the same but in reverse direction as we go |
173 | | -// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension |
174 | | -template<typename T> |
175 | | -structToImpl<std::optional<T>> { |
176 | | -static std::optional<T>call(StableIValue val) { |
177 | | -auto sivp = to<StableIValue*>(val); |
178 | | - |
179 | | -// sivp is either nullptr or a pointer to a StableIValue |
180 | | -if (sivp ==nullptr) { |
181 | | -return {}; |
182 | | - } |
183 | | -auto inner_val = to<T>(*sivp); |
184 | | - |
185 | | -// free the memory associated with StableIValue* sivp |
186 | | -delete sivp; |
187 | | - |
188 | | -returnstd::make_optional(inner_val); |
189 | | - } |
190 | | -}; |
191 | | - |
192 | | -// Specialization for StableIValue => torch::stable::Tensor |
193 | | -// The resulting stable::Tensor steals ownership of the input's |
194 | | -// underlying AtenTensorHandle. |
195 | | -template<> |
196 | | -structToImpl<torch::stable::Tensor> { |
197 | | -static torch::stable::Tensorcall(StableIValue val) { |
198 | | -returntorch::stable::Tensor(to<AtenTensorHandle>(val)); |
199 | | - } |
200 | | -}; |
201 | | - |
202 | | -}// namespace detail |
203 | | - |
204 | | -// Expose the partially templated class functions through single functions |
205 | | -template<typename T> |
206 | | -StableIValuefrom(T val) { |
207 | | -return detail::FromImpl<T>::call(val); |
208 | | -} |
209 | | - |
210 | | -template<typename T> |
211 | | -StableIValuefrom(const std::optional<T>& val) { |
212 | | -return detail::FromImpl<std::optional<T>>::call(val); |
213 | | -} |
214 | | - |
215 | | -// The below overload is used! See https://godbolt.org/z/859cshxrW |
216 | | -// We are suppressing the warning for versions clang12- and gcc11- |
217 | | -[[maybe_unused]] StableIValuefrom(const torch::stable::Tensor& val) { |
218 | | -return detail::FromImpl<torch::stable::Tensor>::call(val); |
219 | | -} |
220 | | - |
221 | | -template<typename T> |
222 | | -Tto(StableIValue val) { |
223 | | -return detail::ToImpl<T>::call(val); |
224 | | -} |
225 | | - |
226 | | -// ============================================================================= |
227 | | -// end to helpers for converting between StableIValue and T |
228 | | -// ============================================================================= |
229 | | - |
230 | 17 | classStableLibraryfinal { |
231 | 18 | private: |
232 | 19 | TorchLibraryHandle lib_; |
|