Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8b27941

Browse files
committed
Add ScalarType -> shim conversion, add stable::Tensor.scalar_type
ghstack-source-id:d51cb60Pull Requestresolved:#160557
1 parentdb0b7f1 commit8b27941

File tree

8 files changed

+534
-378
lines changed

8 files changed

+534
-378
lines changed

‎test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp‎

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,10 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
134134
constauto num_args =6;
135135
StableIValue stack[num_args];
136136

137-
int32_t t_dtype;
138-
aoti_torch_get_dtype(t.get(), &t_dtype);
139137
auto mf =aoti_torch_memory_format_contiguous_format();
140138

141139
stack[0] =from(t);
142-
stack[1] =from(std::optional(t_dtype));// dtype
140+
stack[1] =from(std::optional(t.scalar_type()));// dtype
143141
stack[2] =from(std::nullopt);// layout
144142
stack[3] =from(std::optional(device));// device
145143
stack[4] =from(std::optional(false));// pin_memory

‎test/test_cpp_extensions_jit.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,7 @@ def test_aoti_torch_call_dispatcher(self):
12271227
#include <torch/csrc/inductor/aoti_runtime/utils.h>
12281228
#include <torch/csrc/inductor/aoti_torch/utils.h>
12291229
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
1230-
#include <torch/csrc/stable/library.h>
1230+
#include <torch/csrc/stable/utils.h>
12311231
12321232
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
12331233

‎torch/csrc/stable/library.h‎

Lines changed: 4 additions & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -4,229 +4,16 @@
44
// code for better UX.
55

66
#include<torch/csrc/inductor/aoti_torch/c/shim.h>
7-
#include<torch/csrc/stable/tensor.h>
87

9-
#include<optional>
8+
// Technically, this file doesn't use anything from utils.h, but we
9+
// need to include it here as the contents of utils.h used to live
10+
// here and so we need to expose them for backwards compatibility.
11+
#include<torch/csrc/stable/utils.h>
1012

1113
// use anonymous namespace to avoid collisions between differing
1214
// versions of this file that may be included by different sources
1315
namespace {
1416

15-
// =============================================================================
16-
// helpers for converting between StableIValue and T
17-
// =============================================================================
18-
19-
// forward declare so that from/to() calls in detail work
20-
template<typename T>
21-
StableIValuefrom(T val);
22-
template<typename T>
23-
Tto(StableIValue val);
24-
25-
namespacedetail {
26-
27-
// =============================================================================
28-
// FROM CONVERSIONS (T -> StableIValue)
29-
// =============================================================================
30-
31-
// Specialization for general copyable types (catch-all) => StableIValue
32-
template<typename T>
33-
structFromImpl {
34-
static StableIValuecall(T val) {
35-
static_assert(
36-
sizeof(T) <=sizeof(StableIValue),
37-
"StableLibrary stack does not support parameter types larger than 64 bits.");
38-
static_assert(std::is_trivially_copyable_v<T>);
39-
// Initialization should be cheap enough; let's give people well-specified
40-
// reproducible behavior.
41-
StableIValue result =0;
42-
// NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress
43-
// overzealous -Wclass-memaccess. (see
44-
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a
45-
// static_assert above that T is trivially copyable, which should be
46-
// enough.
47-
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
48-
std::memcpy(&result,reinterpret_cast<constvoid*>(&val),sizeof(val));
49-
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
50-
// if value has size less than sizeof(StableIValue), then only lowest bytes
51-
// have to be updated
52-
std::memcpy(
53-
reinterpret_cast<unsignedchar*>(&result) +sizeof(StableIValue) -
54-
sizeof(val),
55-
reinterpret_cast<constvoid*>(&val),
56-
sizeof(val));
57-
#else
58-
#error Unexpected or undefined __BYTE_ORDER__
59-
#endif
60-
return result;
61-
}
62-
};
63-
64-
// Specialization for std::nullopt_t => StableIValue
65-
template<>
66-
structFromImpl<std::nullopt_t> {
67-
static StableIValuecall(std::nullopt_t val) {
68-
returnfrom(nullptr);
69-
}
70-
};
71-
72-
// Specialization for std::optional => StableIValue
73-
// [Handling std::optional]
74-
// When the schema is represented by an optional type, say int?, then we
75-
// expect the custom extension representation to be a std::optional<int>
76-
// (critically NOT int!). In order for all parameters to be stably parsed and
77-
// handled by our dispatcher, we liaison custom extension parameters through
78-
// boxed kernels, meaning that every value will make its way to be an IValue:
79-
//
80-
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
81-
//
82-
// When the custom extension value is a literal that can be trivially
83-
// casted to StableIValue, e.g., an int, a float, a pointer, this route is
84-
// ...trivial. The below specialization is for a case when the custom
85-
// extension value would NOT fit within a StableIValue: a std::optional.
86-
//
87-
// If the std::optional has no value, it is treated as std::nullopt,
88-
// whose StableIValue representation is from(nullptr). Otherwise, we:
89-
// 1. unwrap the std::optional<T>
90-
// 2. recursively convert its value of type T to a StableIValue
91-
// 3. allocate heap space for said StableIValue
92-
// 4. convert the resulting StableIValue* into a StableIValue
93-
//
94-
// note that this allocates heap memory! which we expect to be cleaned
95-
// up in the to_ivalue() function defined in shim_common.cpp. We
96-
// purposefully hide this implementation detail from the user so that
97-
// all the user needs to know is:
98-
//
99-
// The schema requests an optional (T?) so I must call `from` on a
100-
// std::optional<T> or a std::nullopt.
101-
template<typename T>
102-
structFromImpl<std::optional<T>> {
103-
static StableIValuecall(const std::optional<T>& val) {
104-
if (!val.has_value()) {
105-
returnfrom(std::nullopt);
106-
}
107-
StableIValue* heap_val =newStableIValue(from(val.value()));
108-
returnfrom(heap_val);
109-
}
110-
};
111-
112-
// Specialization for torch::stable::Tensor => StableIValue
113-
// Returns a new owning reference of the underlying Tensor.
114-
template<>
115-
structFromImpl<torch::stable::Tensor> {
116-
static StableIValuecall(const torch::stable::Tensor& val) {
117-
AtenTensorHandle new_ath;
118-
aoti_torch_new_tensor_handle(val.get(), &new_ath);
119-
returnfrom(new_ath);
120-
}
121-
};
122-
123-
// =============================================================================
124-
// TO CONVERSIONS (StableIValue -> T)
125-
// =============================================================================
126-
127-
// Specialization for StableIValue => general copyable types (catch-all)
128-
template<typename T>
129-
structToImpl {
130-
static Tcall(StableIValue val) {
131-
static_assert(std::is_trivially_copyable_v<T>);
132-
// T may not have a default constructor. (For example, it might be
133-
// c10::Device.) However, std::memcpy implicitly creates a T at the
134-
// destination. So, we can use a union to work around this lack of
135-
// default constructor.
136-
union Result {
137-
Result() {}
138-
T t;
139-
};
140-
Result result;
141-
// See NOTE[ -Wclass-memaccess ] above.
142-
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
143-
std::memcpy(reinterpret_cast<void*>(&result.t), &val,sizeof(result));
144-
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
145-
static_assert(
146-
sizeof(T) <=sizeof(StableIValue),
147-
"StableLibrary stack does not support parameter types larger than 64 bits.");
148-
// if value has size less than sizeof(StableIValue), then only lowest bytes
149-
// have to be updated
150-
std::memcpy(
151-
reinterpret_cast<void*>(&result.t),
152-
reinterpret_cast<unsignedchar*>(&val) +sizeof(StableIValue) -
153-
sizeof(result),
154-
sizeof(result));
155-
#else
156-
#error Unexpected or undefined __BYTE_ORDER__
157-
#endif
158-
return result.t;
159-
}
160-
};
161-
162-
// Specialization for StableIValue => std::nullopt_t
163-
template<>
164-
structToImpl<std::nullopt_t> {
165-
static std::nullopt_tcall(StableIValue val) {
166-
// val should be equivalent to from(nullptr)
167-
return std::nullopt;
168-
}
169-
};
170-
171-
// Specialization for StableIValue => std::optional, see [Handling
172-
// std::optional] as the semantic is the same but in reverse direction as we go
173-
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
174-
template<typename T>
175-
structToImpl<std::optional<T>> {
176-
static std::optional<T>call(StableIValue val) {
177-
auto sivp = to<StableIValue*>(val);
178-
179-
// sivp is either nullptr or a pointer to a StableIValue
180-
if (sivp ==nullptr) {
181-
return {};
182-
}
183-
auto inner_val = to<T>(*sivp);
184-
185-
// free the memory associated with StableIValue* sivp
186-
delete sivp;
187-
188-
returnstd::make_optional(inner_val);
189-
}
190-
};
191-
192-
// Specialization for StableIValue => torch::stable::Tensor
193-
// The resulting stable::Tensor steals ownership of the input's
194-
// underlying AtenTensorHandle.
195-
template<>
196-
structToImpl<torch::stable::Tensor> {
197-
static torch::stable::Tensorcall(StableIValue val) {
198-
returntorch::stable::Tensor(to<AtenTensorHandle>(val));
199-
}
200-
};
201-
202-
}// namespace detail
203-
204-
// Expose the partially templated class functions through single functions
205-
template<typename T>
206-
StableIValuefrom(T val) {
207-
return detail::FromImpl<T>::call(val);
208-
}
209-
210-
template<typename T>
211-
StableIValuefrom(const std::optional<T>& val) {
212-
return detail::FromImpl<std::optional<T>>::call(val);
213-
}
214-
215-
// The below overload is used! See https://godbolt.org/z/859cshxrW
216-
// We are suppressing the warning for versions clang12- and gcc11-
217-
[[maybe_unused]] StableIValuefrom(const torch::stable::Tensor& val) {
218-
return detail::FromImpl<torch::stable::Tensor>::call(val);
219-
}
220-
221-
template<typename T>
222-
Tto(StableIValue val) {
223-
return detail::ToImpl<T>::call(val);
224-
}
225-
226-
// =============================================================================
227-
// end to helpers for converting between StableIValue and T
228-
// =============================================================================
229-
23017
classStableLibraryfinal {
23118
private:
23219
TorchLibraryHandle lib_;

‎torch/csrc/stable/ops.h‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include<torch/csrc/stable/library.h>
3+
#include<torch/csrc/stable/utils.h>
44
#include<array>
55
#include<cstdint>
66
#include<optional>

‎torch/csrc/stable/tensor-inl.h‎

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
// This file implements tensor.h. We separated out the Tensor struct so that
4+
// other files can depend on the Tensor struct (like library.h) and the
5+
// implementations of the Tensor methods can depend on APIs in library.h
6+
// without circular dependencies.
7+
8+
#pragma once
9+
#include<torch/csrc/stable/tensor.h>
10+
#include<torch/csrc/stable/utils.h>
11+
#include<torch/headeronly/core/ScalarType.h>
12+
#include<torch/headeronly/util/shim_utils.h>
13+
14+
namespacetorch::stable {
15+
16+
using torch::headeronly::ScalarType;
17+
18+
ScalarTypeTensor::scalar_type()const {
19+
int32_t dtype;
20+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(ath_.get(), &dtype));
21+
return to<ScalarType>(from(dtype));
22+
}
23+
24+
}// namespace torch::stable

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp