Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitcfb08ae

Browse files
committed
Fixing constructor bug pytensor<..., 0>
1 parent9aa58f8 commitcfb08ae

File tree

4 files changed

+39
-4
lines changed

4 files changed

+39
-4
lines changed

‎include/xtensor-python/pytensor.hpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,26 @@ namespace pybind11
100100
}
101101
};
102102

103-
}
103+
}// namespace detail
104104
}
105105

106106
namespacext
107107
{
108+
namespacedetail {
109+
110+
template<std::size_t N,typename =void>
111+
structnumpy_strides
112+
{
113+
npy_intp value[N];
114+
};
115+
116+
template<std::size_t N>
117+
structnumpy_strides<N,typename std::enable_if_t<(N ==0)>::type>
118+
{
119+
npy_intp* value =nullptr;
120+
};
121+
122+
}// namespace detail
108123

109124
template<classT, std::size_t N, layout_type L>
110125
structxiterable_inner_types<pytensor<T, N, L>>
@@ -433,8 +448,8 @@ namespace xt
433448
template<classT, std::size_t N, layout_type L>
434449
inlinevoid pytensor<T, N, L>::init_tensor(const shape_type& shape,const strides_type& strides)
435450
{
436-
npy_intp python_strides[N];
437-
std::transform(strides.begin(), strides.end(), python_strides,
451+
detail::numpy_strides<N> python_strides;
452+
std::transform(strides.begin(), strides.end(), python_strides.value,
438453
[](auto v) {returnsizeof(T) * v; });
439454
int flags = NPY_ARRAY_ALIGNED;
440455
if (!std::is_const<T>::value)
@@ -445,7 +460,7 @@ namespace xt
445460

446461
auto tmp = pybind11::reinterpret_steal<pybind11::object>(
447462
PyArray_NewFromDescr(&PyArray_Type, (PyArray_Descr*) dtype.release().ptr(),static_cast<int>(shape.size()),
448-
const_cast<npy_intp*>(shape.data()), python_strides,
463+
const_cast<npy_intp*>(shape.data()), python_strides.value,
449464
nullptr, flags,nullptr));
450465

451466
if (!tmp)

‎test/test_pytensor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ namespace xt
6565
EXPECT_THROW(pyt3::from_shape(shp), std::runtime_error);
6666
}
6767

68+
TEST(pytensor, scalar_from_shape)
69+
{
70+
std::array<size_t,0> shape;
71+
auto a = pytensor<double,0>::from_shape(shape);
72+
pytensor<double,0>b(1.2);
73+
EXPECT_TRUE(a.size() == b.size());
74+
EXPECT_TRUE(xt::has_shape(a, b.shape()));
75+
}
76+
6877
TEST(pytensor, strided_constructor)
6978
{
7079
central_major_result<container_type> cmr;

‎test_python/main.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ void col_major_array(xt::pyarray<double, xt::layout_type::column_major>& arg)
227227
}
228228
}
229229

230+
xt::pytensor<int,0>xscalar(const xt::pytensor<int,1>& arg)
231+
{
232+
returnxt::sum(arg);
233+
}
234+
230235
template<classT>
231236
using ndarray = xt::pyarray<T, xt::layout_type::row_major>;
232237

@@ -285,6 +290,8 @@ PYBIND11_MODULE(xtensor_python_test, m)
285290
m.def("col_major_array", col_major_array);
286291
m.def("row_major_tensor", row_major_tensor);
287292

293+
m.def("xscalar", xscalar);
294+
288295
py::class_<C>(m,"C")
289296
.def(py::init<>())
290297
.def_property_readonly(

‎test_python/test_pyarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ def test_col_row_major(self):
151151
xt.col_major_array(varF)
152152
xt.col_major_array(varF[:, :,0])# still col major!
153153

154+
deftest_xscalar(self):
155+
var=np.arange(50,dtype=int)
156+
self.assertTrue(np.sum(var)==xt.xscalar(var))
157+
154158
deftest_bad_argument_call(self):
155159
withself.assertRaises(TypeError):
156160
xt.simple_array("foo")

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp