Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitceba941

Browse files
authored
Merge pull request#218 from dboyliao/transpose-op
tranpose op: perm as input tensor
2 parentsb311042 +1f40fcb commitceba941

File tree

3 files changed

+88
-68
lines changed

3 files changed

+88
-68
lines changed
Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
#ifndef _TRANSPOSE_TEST_H
22
#define_TRANSPOSE_TEST_H
33

4-
staticconstunsignedshort transpose_axes_arr[3] = {2,1,0 };
5-
staticconstfloat random_input_arr[15] = {3.484638214111328,2.033799886703491,3.2437448501586914,4.783249855041504,3.497023582458496,3.511240005493164,1.558927297592163,3.7084484100341797,2.570117712020874,0.2405869960784912,1.8713605403900146,4.19132661819458,0.6596618890762329,0.9029078483581543,0.2223271131515503 };
6-
staticconstfloat ref_output_arr[15] = {3.484638214111328,3.511240005493164,1.8713605403900146,2.033799886703491,1.558927297592163,4.19132661819458,3.2437448501586914,3.7084484100341797,0.6596618890762329,4.783249855041504,2.570117712020874,0.9029078483581543,3.497023582458496,0.2405869960784912,0.2223271131515503 };
4+
staticconstint32_t transpose_perm_arr[4] = {2,1,0,3};
5+
staticconstfloat random_input_arr[15] = {
6+
3.484638214111328,2.033799886703491,3.2437448501586914,
7+
4.783249855041504,3.497023582458496,3.511240005493164,
8+
1.558927297592163,3.7084484100341797,2.570117712020874,
9+
0.2405869960784912,1.8713605403900146,4.19132661819458,
10+
0.6596618890762329,0.9029078483581543,0.2223271131515503};
11+
staticconstfloat ref_output_arr[15] = {
12+
3.484638214111328,3.511240005493164,1.8713605403900146,
13+
2.033799886703491,1.558927297592163,4.19132661819458,
14+
3.2437448501586914,3.7084484100341797,0.6596618890762329,
15+
4.783249855041504,2.570117712020874,0.9029078483581543,
16+
3.497023582458496,0.2405869960784912,0.2223271131515503};
717

8-
#endif//_TRANSPOSE
18+
#endif//_TRANSPOSE_TEST_H

‎TESTS/operators/test_transpose.cpp‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#include<iostream>
33

44
#include"RamTensor.hpp"
5-
#include"Transpose.hpp"
65
#include"RomTensor.hpp"
6+
#include"Transpose.hpp"
77
#include"arenaAllocator.hpp"
88
#include"constants_transpose.hpp"
99
#include"context.hpp"
@@ -19,19 +19,19 @@ TEST(Transpose, transpose_test) {
1919
localCircularArenaAllocator<15 *2 *sizeof(float),uint32_t> ram_allocator;
2020
Context::get_default_context()->set_metadata_allocator(&meta_allocator);
2121
Context::get_default_context()->set_ram_data_allocator(&ram_allocator);
22-
22+
2323
Tensor input_tensor =newRomTensor({3,1,5,1}, flt, random_input_arr);
24+
Tensor perm_tensor =newRomTensor({4},i32, transpose_perm_arr);
2425

2526
TensorShapeinput_target_shape(3,1,5,1);
2627
TensorShape input_shape = input_tensor->get_shape();
2728
EXPECT_TRUE(input_target_shape == input_shape);
2829

29-
Tensor transpose_axes =newRomTensor({4},u8, transpose_axes_arr);
3030
Tensor output_tensor =newRamTensor(flt);
31-
TransposeOperator<float>op({2,1,0,3});
32-
31+
TransposeOperator<float> op;
3332

34-
op.set_inputs({{TransposeOperator<float>::input, input_tensor}})
33+
op.set_inputs({{TransposeOperator<float>::input, input_tensor},
34+
{TransposeOperator<float>::perm, perm_tensor}})
3535
.set_outputs({{TransposeOperator<float>::output, output_tensor}})
3636
.eval();
3737

‎src/uTensor/ops/Transpose.hpp‎

Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,38 @@
11
#ifndef UTENSOR_TRANSPOSE_H
22
#defineUTENSOR_TRANSPOSE_H
33

4+
#include<cstring>
5+
46
#include"context.hpp"
5-
#include"types.hpp"
7+
#include"operatorBase.hpp"
68
#include"tensor.hpp"
9+
#include"types.hpp"
710
#include"uTensor_util.hpp"
8-
#include"operatorBase.hpp"
9-
10-
#include<cstring>
1111

1212
namespaceuTensor {
1313
namespaceReferenceOperators {
1414

1515
// Transpose (Swap Axes) as a port from Numpy
1616
// using stride interation in the order of transpose axes
1717
template<typename Tin>
18-
classTransposeOperator :publicOperatorInterface<1,1> {
19-
/* reshape input as the shape of output*/
20-
public:
21-
TransposeOperator(const TensorShape&& axes) : _axes(axes) {}
22-
TransposeOperator(const TensorShape& axes) : _axes(axes) {}
23-
24-
enum names_in :uint8_t { input };
18+
classTransposeOperator :publicOperatorInterface<2,1> {
19+
/* reshape input as the shape of output*/
20+
public:
21+
enum names_in :uint8_t { input, perm };
2522
enum names_out :uint8_t { output };
2623

27-
virtualvoidcompute(){
24+
virtualvoidcompute() {
25+
const Tensor& perm_tensor = inputs[perm].tensor();
26+
if (perm_tensor.get_shape().num_dims() >1) {
27+
uTensor_printf(
28+
"the input tensor perm should be a vector (dimension should be 1)\n");
29+
Context::get_default_context()->throwError(new InvalidTensorInputError);
30+
}
31+
if (perm_tensor->get_type() !=i32) {
32+
uTensor_printf("expecting perm tensor of element type int32_t\n");
33+
Context::get_default_context()->throwError(
34+
new InvalidTensorDataTypeError);
35+
}
2836
Tensor& input_tensor = inputs[input].tensor();
2937
TensorShape& input_shape = input_tensor.get_shape();
3038
input_shape.update_dims();
@@ -36,78 +44,80 @@ class TransposeOperator : public OperatorInterface<1, 1> {
3644
Tensor& output_tensor = outputs[output].tensor();
3745

3846
// Create a placeholder to calculate the output shape
39-
// Normally this would reference output shape, but since this could (usually would) be referencing the input, let's keep a dedicated value
40-
TensorShape output_shape =TensorShape(1,1,1,1);
47+
// Normally this would reference output shape, but since this could (usually
48+
// would) be referencing the input, let's keep a dedicated value
49+
TensorShape output_shape =TensorShape(1,1,1,1);
4150
TensorStrides output_strides =TensorStrides(output_shape);
4251
TensorShape offsets =TensorShape(input_shape.num_dims());
4352

44-
for (size_t i =0; i <4; ++i) {
53+
for (size_t i =0; i <4; ++i) {
4554
output_shape[i] =0;
4655
output_strides[i] =0;
4756

4857
// Offsets are used to avoid multiple for loops
4958
offsets[i] =0;
5059
}
5160

52-
for (size_t i =0; i < (size_t) input_shape.num_dims(); ++i) {
53-
output_shape[_axes[i]] = input_shape[i];
61+
for (size_t i =0; i < (size_t)input_shape.num_dims(); ++i) {
62+
int32_t axis =static_cast<int32_t>(perm_tensor(i));
63+
output_shape[axis] = input_shape[i];
5464

5565
// output_strides(i) is derived from axes and input_strides
56-
output_strides[_axes[i]] = input_strides[i];
66+
output_strides[axis] = input_strides[i];
5767
}
58-
59-
// Output shape can be asserted once the transform
68+
69+
// Output shape can be asserted once the transform
6070
// effect has been determined
6171
output_shape.update_dims();
6272
output_tensor->resize(output_shape);
6373

6474
// Perform some basic checks
65-
if (input_tensor->num_elems() != output_tensor->num_elems()){
66-
uTensor_printf("inconsistent input and output shape for reshape\n");
67-
Context::get_default_context()->throwError(new InvalidReshapeError);
68-
return;
69-
}
70-
if (input_tensor->get_type() != output_tensor->get_type()){
71-
uTensor_printf("inconsistent input and output data type for reshape\n");
72-
Context::get_default_context()->throwError(new InvalidTensorDataTypeError);
73-
return;
75+
if (input_tensor->num_elems() != output_tensor->num_elems()) {
76+
uTensor_printf("inconsistent input and output shape for reshape\n");
77+
Context::get_default_context()->throwError(new InvalidReshapeError);
78+
return;
79+
}
80+
if (input_tensor->get_type() != output_tensor->get_type()) {
81+
uTensor_printf("inconsistent input and output data type for reshape\n");
82+
Context::get_default_context()->throwError(
83+
new InvalidTensorDataTypeError);
84+
return;
7485
}
75-
if (!_check_input_shape()){
76-
Context::get_default_context()->throwError(new InvalidTensorDataTypeError);
77-
return;
86+
if (!_check_input_shape()) {
87+
Context::get_default_context()->throwError(
88+
new InvalidTensorDataTypeError);
89+
return;
7890
}
7991

8092
// copy data
81-
for (uint32_t i =0; i < input_tensor->num_elems(); ++i) {
82-
// Index of the source value, must be calculated
83-
// using the output strides and output shape
84-
uint32_t idx =0;
85-
for (uint32_t j =0; j < output_shape.num_dims(); j++) {
86-
idx += offsets[j] * output_strides[j];
87-
}
88-
89-
// this is not copy: `output_tensor(i) = input_tensor(i);`
90-
output_tensor(i) =static_cast<Tin>(input_tensor(idx));
93+
for (uint32_t i =0; i < input_tensor->num_elems(); ++i) {
94+
// Index of the source value, must be calculated
95+
// using the output strides and output shape
96+
uint32_t idx =0;
97+
for (uint32_t j =0; j < output_shape.num_dims(); j++) {
98+
idx += offsets[j] * output_strides[j];
99+
}
91100

92-
// Update offsets, to iterate sequentially along strides
93-
// in the order of axes
94-
for (int32_t j = output_shape.num_dims() -1; j >=0; j--) {
95-
offsets[j] = (offsets[j] +1) % (output_shape[j]);
96-
if( offsets[j] >0 ) {
97-
break;
98-
}
99-
}
100-
}
101+
// this is not copy: `output_tensor(i) = input_tensor(i);`
102+
output_tensor(i) =static_cast<Tin>(input_tensor(idx));
101103

104+
// Update offsets, to iterate sequentially along strides
105+
// in the order of axes
106+
for (int32_t j = output_shape.num_dims() -1; j >=0; j--) {
107+
offsets[j] = (offsets[j] +1) % (output_shape[j]);
108+
if (offsets[j] >0) {
109+
break;
110+
}
111+
}
112+
}
102113
}
103-
private:
104-
TensorShape _axes;
105114

106-
bool_check_input_shape(){
115+
private:
116+
bool_check_input_shape() {
107117
const Tensor& input_tensor = inputs[input].tensor();
108118
const TensorShape& shape = input_tensor->get_shape();
109119
uint8_t num_dims = shape.num_dims();
110-
for (int i =0; i < num_dims; ++i){
120+
for (int i =0; i < num_dims; ++i){
111121
if (shape[i] <0) {
112122
uTensor_printf("the output shape must be all positive\n");
113123
returnfalse;
@@ -117,7 +127,7 @@ class TransposeOperator : public OperatorInterface<1, 1> {
117127
}
118128
};
119129

120-
}
121-
}
130+
}// namespace ReferenceOperators
131+
}// namespace uTensor
122132

123-
#endif// UTENSOR_TRANSPOSE_H
133+
#endif// UTENSOR_TRANSPOSE_H

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp