Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit06d86c9

Browse files
authored
Merge pull request#267 from tdegeus/qad
Adding possibility to 'cast' or copy to `xt::xarray` etc
2 parents43b244e +af91def commit06d86c9

File tree

9 files changed

+276
-18
lines changed

9 files changed

+276
-18
lines changed

‎.azure-pipelines/unix-build.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ steps:
4545
displayName: Example - readme 1
4646
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/readme_example_1
4747
48+
-script:|
49+
source activate xtensor-python
50+
cmake -Bbuild -DPython_EXECUTABLE=`which python`
51+
cd build
52+
cmake --build .
53+
cp ../example.py .
54+
python example.py
55+
cd ..
56+
displayName: Example - Copy 'cast'
57+
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/copy_cast
58+
4859
-script:|
4960
source activate xtensor-python
5061
cmake -Bbuild -DPython_EXECUTABLE=`which python`

‎docs/source/examples.rst

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,54 @@ Then we can test the module:
143143
Since we did not install the module,
144144
we should compile and run the example from the same folder.
145145
To install, please consult
146-
`this*pybind11* /*CMake* example<https://github.com/pybind/cmake_example>`_.
146+
`this pybind11 / CMake example<https://github.com/pybind/cmake_example>`_.
147147
**Tip**: take care to modify that example with the correct *CMake* case ``Python_EXECUTABLE``.
148+
149+
Fall-back cast
150+
==============
151+
152+
The previous example showed you how to design your module to be flexible in accepting data.
153+
From C++ we used ``xt::xarray<double>``,
154+
whereas for the Python API we used ``xt::pyarray<double>`` to operate directly on the memory
155+
of a NumPy array from Python (without copying the data).
156+
157+
Sometimes, you might not have the flexibility to design your module's methods
158+
with template parameters.
159+
This might occur when you want to ``override`` functions
160+
(though it is recommended to use CRTP to still use templates).
161+
In this case we can still bind the module in Python using *xtensor-python*,
162+
however, we have to copy the data from a (NumPy) array.
163+
This means that although the following signatures are quite different when used from C++,
164+
as follows:
165+
166+
1. *Constant reference*: read from the data, without copying it.
167+
168+
..code-block::cpp
169+
170+
void foo(const xt::xarray<double>& a);
171+
172+
2. *Reference*: read from and/or write to the data, without copying it.
173+
174+
..code-block::cpp
175+
176+
void foo(xt::xarray<double>& a);
177+
178+
3. *Copy*: copy the data.
179+
180+
..code-block::cpp
181+
182+
void foo(xt::xarray<double> a);
183+
184+
The Python will all cases result in a copy to a temporary variable
185+
(though the last signature will lead to a copy to a temporary variable, and another copy to ``a``).
186+
On the one hand, this is more costly than when using ``xt::pyarray`` and ``xt::pyxtensor``,
187+
on the other hand, it means that all changes you make to a reference, are made to the temporary
188+
copy, and are thus lost.
189+
190+
Still, it might be a convenient way to create Python bindings, using a minimal effort.
191+
Consider this example:
192+
193+
:download:`main.cpp<examples/copy_cast/main.cpp>`
194+
195+
..literalinclude::examples/copy_cast/main.cpp
196+
:language: cpp
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cmake_minimum_required(VERSION 3.1..3.19)
2+
3+
project(mymodule)
4+
5+
find_package(pybind11 CONFIG REQUIRED)
6+
find_package(xtensor REQUIRED)
7+
find_package(xtensor-python REQUIRED)
8+
find_package(Python REQUIRED COMPONENTS NumPy)
9+
10+
pybind11_add_module(mymodule main.cpp)
11+
target_link_libraries(mymodulePUBLIC pybind11::module xtensor-python Python::NumPy)
12+
13+
target_compile_definitions(mymodulePRIVATE VERSION_INFO=0.1.0)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
importmymodule
2+
importnumpyasnp
3+
4+
c=np.array([[1,2,3], [4,5,6]])
5+
assertnp.isclose(np.sum(np.sin(c)),mymodule.sum_of_sines(c))
6+
assertnp.isclose(np.sum(np.cos(c)),mymodule.sum_of_cosines(c))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include<numeric>
2+
#include<xtensor.hpp>
3+
#include<pybind11/pybind11.h>
4+
#defineFORCE_IMPORT_ARRAY
5+
#include<xtensor-python/pyarray.hpp>
6+
7+
template<classT>
8+
doublesum_of_sines(T& m)
9+
{
10+
auto sines =xt::sin(m);// sines does not actually hold values.
11+
returnstd::accumulate(sines.begin(), sines.end(),0.0);
12+
}
13+
14+
// In the Python API this a reference to a temporary variable
15+
doublesum_of_cosines(const xt::xarray<double>& m)
16+
{
17+
auto cosines =xt::cos(m);// cosines does not actually hold values.
18+
returnstd::accumulate(cosines.begin(), cosines.end(),0.0);
19+
}
20+
21+
PYBIND11_MODULE(mymodule, m)
22+
{
23+
xt::import_numpy();
24+
m.doc() ="Test module for xtensor python bindings";
25+
m.def("sum_of_sines", sum_of_sines<xt::pyarray<double>>,"Sum the sines of the input values");
26+
m.def("sum_of_cosines", sum_of_cosines,"Sum the cosines of the input values");
27+
}

‎include/xtensor-python/pynative_casters.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include"xtensor_type_caster_base.hpp"
1414

15-
1615
namespacepybind11
1716
{
1817
namespacedetail

‎include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,97 @@ namespace pybind11
2323
{
2424
namespacedetail
2525
{
26+
template<typename T, xt::layout_type L>
27+
structpybind_array_getter_impl
28+
{
29+
staticautorun(handle src)
30+
{
31+
returnarray_t<T, array::c_style | array::forcecast>::ensure(src);
32+
}
33+
};
34+
35+
template<typename T>
36+
structpybind_array_getter_impl<T, xt::layout_type::column_major>
37+
{
38+
staticautorun(handle src)
39+
{
40+
returnarray_t<T, array::f_style | array::forcecast>::ensure(src);
41+
}
42+
};
43+
44+
template<classT>
45+
structpybind_array_getter
46+
{
47+
};
48+
49+
template<classT, xt::layout_type L>
50+
structpybind_array_getter<xt::xarray<T, L>>
51+
{
52+
staticautorun(handle src)
53+
{
54+
return pybind_array_getter_impl<T, L>::run(src);
55+
}
56+
};
57+
58+
template<classT, std::size_t N, xt::layout_type L>
59+
structpybind_array_getter<xt::xtensor<T, N, L>>
60+
{
61+
staticautorun(handle src)
62+
{
63+
return pybind_array_getter_impl<T, L>::run(src);
64+
}
65+
};
66+
67+
template<classCT,classS, xt::layout_type L,classFST>
68+
structpybind_array_getter<xt::xstrided_view<CT, S, L, FST>>
69+
{
70+
staticautorun(handle/*src*/)
71+
{
72+
returnfalse;
73+
}
74+
};
75+
76+
template<classEC, xt::layout_type L,classSC,classTag>
77+
structpybind_array_getter<xt::xarray_adaptor<EC, L, SC, Tag>>
78+
{
79+
staticautorun(handle src)
80+
{
81+
auto buf = pybind_array_getter_impl<EC, L>::run(src);
82+
return buf;
83+
}
84+
};
85+
86+
template<classEC, std::size_t N, xt::layout_type L,classTag>
87+
structpybind_array_getter<xt::xtensor_adaptor<EC, N, L, Tag>>
88+
{
89+
staticautorun(handle/*src*/)
90+
{
91+
returnfalse;
92+
}
93+
};
94+
95+
96+
template<classT>
97+
structpybind_array_dim_checker
98+
{
99+
template<classB>
100+
staticboolrun(const B& buf)
101+
{
102+
returntrue;
103+
}
104+
};
105+
106+
template<classT, std::size_t N, xt::layout_type L>
107+
structpybind_array_dim_checker<xt::xtensor<T, N, L>>
108+
{
109+
template<classB>
110+
staticboolrun(const B& buf)
111+
{
112+
return buf.ndim() == N;
113+
}
114+
};
115+
116+
26117
// Casts a strided expression type to numpy array.If given a base,
27118
// the numpy array references the src data, otherwise it'll make a copy.
28119
// The writeable attributes lets you specify writeable flag for the array.
@@ -74,10 +165,6 @@ namespace pybind11
74165
template<classType>
75166
structxtensor_type_caster_base
76167
{
77-
boolload(handle/*src*/,bool)
78-
{
79-
returnfalse;
80-
}
81168

82169
private:
83170

@@ -106,6 +193,36 @@ namespace pybind11
106193

107194
public:
108195

196+
PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<typename Type::value_type>::name + _("]"));
197+
198+
boolload(handle src,bool convert)
199+
{
200+
using T =typename Type::value_type;
201+
202+
if (!convert && !array_t<T>::check_(src))
203+
{
204+
returnfalse;
205+
}
206+
207+
auto buf = pybind_array_getter<Type>::run(src);
208+
209+
if (!buf)
210+
{
211+
returnfalse;
212+
}
213+
if (!pybind_array_dim_checker<Type>::run(buf))
214+
{
215+
returnfalse;
216+
}
217+
218+
std::vector<size_t>shape(buf.ndim());
219+
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin());
220+
value =Type::from_shape(shape);
221+
std::copy(buf.data(), buf.data() + buf.size(), value.data());
222+
223+
returntrue;
224+
}
225+
109226
// Normal returned non-reference, non-const value:
110227
static handlecast(Type&& src, return_value_policy/* policy*/, handle parent)
111228
{
@@ -151,18 +268,6 @@ namespace pybind11
151268
{
152269
returncast_impl(src, policy, parent);
153270
}
154-
155-
#ifdef PYBIND11_DESCR// The macro is removed from pybind11 since 2.3
156-
static PYBIND11_DESCRname()
157-
{
158-
return_("xt::xtensor");
159-
}
160-
#else
161-
staticconstexprauto name = _("xt::xtensor");
162-
#endif
163-
164-
template<typename T>
165-
using cast_op_type = cast_op_type<T>;
166271
};
167272
}
168273
}

‎test_python/main.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,33 @@ xt::pyarray<double> example2(xt::pyarray<double>& m)
3333
return m +2;
3434
}
3535

36+
xt::xarray<int>example3_xarray(const xt::xarray<int>& m)
37+
{
38+
returnxt::transpose(m) +2;
39+
}
40+
41+
xt::xarray<int, xt::layout_type::column_major>example3_xarray_colmajor(
42+
const xt::xarray<int, xt::layout_type::column_major>& m)
43+
{
44+
returnxt::transpose(m) +2;
45+
}
46+
47+
xt::xtensor<int,3>example3_xtensor3(const xt::xtensor<int,3>& m)
48+
{
49+
returnxt::transpose(m) +2;
50+
}
51+
52+
xt::xtensor<int,2>example3_xtensor2(const xt::xtensor<int,2>& m)
53+
{
54+
returnxt::transpose(m) +2;
55+
}
56+
57+
xt::xtensor<int,2, xt::layout_type::column_major>example3_xtensor2_colmajor(
58+
const xt::xtensor<int,2, xt::layout_type::column_major>& m)
59+
{
60+
returnxt::transpose(m) +2;
61+
}
62+
3663
// Readme Examples
3764

3865
doublereadme_example1(xt::pyarray<double>& m)
@@ -249,6 +276,11 @@ PYBIND11_MODULE(xtensor_python_test, m)
249276

250277
m.def("example1", example1);
251278
m.def("example2", example2);
279+
m.def("example3_xarray", example3_xarray);
280+
m.def("example3_xarray_colmajor", example3_xarray_colmajor);
281+
m.def("example3_xtensor3", example3_xtensor3);
282+
m.def("example3_xtensor2", example3_xtensor2);
283+
m.def("example3_xtensor2_colmajor", example3_xtensor2_colmajor);
252284

253285
m.def("complex_overload", no_complex_overload);
254286
m.def("complex_overload", complex_overload);

‎test_python/test_pyarray.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def test_example2(self):
3636
y=xt.example2(x)
3737
np.testing.assert_allclose(y,res,1e-12)
3838

39+
deftest_example3(self):
40+
x=np.arange(2*3).reshape(2,3)
41+
xc=np.asfortranarray(x)
42+
y=np.arange(2*3*4).reshape(2,3,4)
43+
v=y[1:,1:,0]
44+
z=np.arange(2*3*4*5).reshape(2,3,4,5)
45+
np.testing.assert_array_equal(xt.example3_xarray(x),x.T+2)
46+
np.testing.assert_array_equal(xt.example3_xarray_colmajor(xc),xc.T+2)
47+
np.testing.assert_array_equal(xt.example3_xtensor3(y),y.T+2)
48+
np.testing.assert_array_equal(xt.example3_xtensor2(x),x.T+2)
49+
np.testing.assert_array_equal(xt.example3_xtensor2(y[1:,1:,0]),v.T+2)
50+
np.testing.assert_array_equal(xt.example3_xtensor2_colmajor(xc),xc.T+2)
51+
52+
withself.assertRaises(TypeError):
53+
xt.example3_xtensor3(x)
54+
3955
deftest_vectorize(self):
4056
x1=np.array([[0,1], [2,3]])
4157
x2=np.array([0,1])

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp