Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit1be71a7

Browse files
committed
Adding rank member
1 parentd6f87cf commit1be71a7

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

‎include/xtensor-python/pyarray.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ namespace xt
158158
using inner_shape_type =typename base_type::inner_shape_type;
159159
using inner_strides_type =typename base_type::inner_strides_type;
160160
using inner_backstrides_type =typename base_type::inner_backstrides_type;
161+
constexprstatic std::size_t rank = SIZE_MAX;
161162

162163
pyarray();
163164
pyarray(const value_type& t);
@@ -514,7 +515,7 @@ namespace xt
514515
{
515516
return;
516517
}
517-
518+
518519
m_shape =inner_shape_type(reinterpret_cast<size_type*>(PyArray_SHAPE(this->python_array())),
519520
static_cast<size_type>(PyArray_NDIM(this->python_array())));
520521
m_strides =inner_strides_type(reinterpret_cast<difference_type*>(PyArray_STRIDES(this->python_array())),

‎include/xtensor-python/pytensor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ namespace xt
168168
using inner_shape_type =typename base_type::inner_shape_type;
169169
using inner_strides_type =typename base_type::inner_strides_type;
170170
using inner_backstrides_type =typename base_type::inner_backstrides_type;
171+
constexprstatic std::size_t rank = N;
171172

172173
pytensor();
173174
pytensor(nested_initializer_list_t<T, N> t);
@@ -471,7 +472,7 @@ namespace xt
471472
{
472473
return;
473474
}
474-
475+
475476
if (PyArray_NDIM(this->python_array()) != N)
476477
{
477478
throwstd::runtime_error("NumPy: ndarray has incorrect number of dimensions");

‎test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ set(XTENSOR_PYTHON_TESTS
8686
test_pyarray.cpp
8787
test_pytensor.cpp
8888
test_pyvectorize.cpp
89+
test_sfinae.cpp
8990
)
9091

9192
add_executable(test_xtensor_python${XTENSOR_PYTHON_TESTS}${XTENSOR_PYTHON_HEADERS})

‎test/test_sfinae.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/***************************************************************************
2+
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#include<limits>
11+
12+
#include"gtest/gtest.h"
13+
#include"xtensor-python/pytensor.hpp"
14+
#include"xtensor-python/pyarray.hpp"
15+
#include"xtensor/xarray.hpp"
16+
#include"xtensor/xtensor.hpp"
17+
18+
namespacext
19+
{
20+
template<classE, std::enable_if_t<!xt::has_fixed_rank_t<E>::value,int> =0>
21+
inlineboolsfinae_has_fixed_rank(E&&)
22+
{
23+
returnfalse;
24+
}
25+
26+
template<classE, std::enable_if_t<xt::has_fixed_rank_t<E>::value,int> =0>
27+
inlineboolsfinae_has_fixed_rank(E&&)
28+
{
29+
returntrue;
30+
}
31+
32+
TEST(sfinae, fixed_rank)
33+
{
34+
xt::pyarray<size_t> a = {{9,9,9}, {9,9,9}};
35+
xt::pytensor<size_t,1> b = {9,9};
36+
xt::pytensor<size_t,2> c = {{9,9}, {9,9}};
37+
38+
EXPECT_TRUE(sfinae_has_fixed_rank(a) ==false);
39+
EXPECT_TRUE(sfinae_has_fixed_rank(b) ==true);
40+
EXPECT_TRUE(sfinae_has_fixed_rank(c) ==true);
41+
}
42+
43+
TEST(sfinae, get_rank)
44+
{
45+
xt::pytensor<double,1> A = xt::zeros<double>({2});
46+
xt::pytensor<double,2> B = xt::zeros<double>({2,2});
47+
xt::pyarray<double> C = xt::zeros<double>({2,2});
48+
49+
EXPECT_TRUE(xt::get_rank<decltype(A)>::value ==1ul);
50+
EXPECT_TRUE(xt::get_rank<decltype(B)>::value ==2ul);
51+
EXPECT_TRUE(xt::get_rank<decltype(C)>::value == SIZE_MAX);
52+
}
53+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp