11#ifndef UTENSOR_TRANSPOSE_H
22#define UTENSOR_TRANSPOSE_H
33
4+ #include < cstring>
5+
46#include " context.hpp"
5- #include " types .hpp"
7+ #include " operatorBase .hpp"
68#include " tensor.hpp"
9+ #include " types.hpp"
710#include " uTensor_util.hpp"
8- #include " operatorBase.hpp"
9-
10- #include < cstring>
1111
1212namespace uTensor {
1313namespace ReferenceOperators {
1414
1515// Transpose (Swap Axes) as a port from Numpy
1616// using stride interation in the order of transpose axes
1717template <typename Tin>
18- class TransposeOperator :public OperatorInterface <1 ,1 > {
19- /* reshape input as the shape of output*/
20- public:
21- TransposeOperator (const TensorShape&& axes) : _axes(axes) {}
22- TransposeOperator (const TensorShape& axes) : _axes(axes) {}
23-
24- enum names_in :uint8_t { input };
18+ class TransposeOperator :public OperatorInterface <2 ,1 > {
19+ /* reshape input as the shape of output*/
20+ public:
21+ enum names_in :uint8_t { input, perm };
2522enum names_out :uint8_t { output };
2623
27- virtual void compute (){
24+ virtual void compute () {
25+ const Tensor& perm_tensor = inputs[perm].tensor ();
26+ if (perm_tensor.get_shape ().num_dims () >1 ) {
27+ uTensor_printf (
28+ " the input tensor perm should be a vector (dimension should be 1)\n " );
29+ Context::get_default_context ()->throwError (new InvalidTensorInputError);
30+ }
31+ if (perm_tensor->get_type () !=i32 ) {
32+ uTensor_printf (" expecting perm tensor of element type int32_t\n " );
33+ Context::get_default_context ()->throwError (
34+ new InvalidTensorDataTypeError);
35+ }
2836 Tensor& input_tensor = inputs[input].tensor ();
2937 TensorShape& input_shape = input_tensor.get_shape ();
3038 input_shape.update_dims ();
@@ -36,78 +44,80 @@ class TransposeOperator : public OperatorInterface<1, 1> {
3644 Tensor& output_tensor = outputs[output].tensor ();
3745
3846// Create a placeholder to calculate the output shape
39- // Normally this would reference output shape, but since this could (usually would) be referencing the input, let's keep a dedicated value
40- TensorShape output_shape =TensorShape (1 ,1 ,1 ,1 );
47+ // Normally this would reference output shape, but since this could (usually
48+ // would) be referencing the input, let's keep a dedicated value
49+ TensorShape output_shape =TensorShape (1 ,1 ,1 ,1 );
4150 TensorStrides output_strides =TensorStrides (output_shape);
4251 TensorShape offsets =TensorShape (input_shape.num_dims ());
4352
44- for (size_t i =0 ; i <4 ; ++i) {
53+ for (size_t i =0 ; i <4 ; ++i) {
4554 output_shape[i] =0 ;
4655 output_strides[i] =0 ;
4756
4857// Offsets are used to avoid multiple for loops
4958 offsets[i] =0 ;
5059 }
5160
52- for (size_t i =0 ; i < (size_t ) input_shape.num_dims (); ++i) {
53- output_shape[_axes[i]] = input_shape[i];
61+ for (size_t i =0 ; i < (size_t )input_shape.num_dims (); ++i) {
62+ int32_t axis =static_cast <int32_t >(perm_tensor (i));
63+ output_shape[axis] = input_shape[i];
5464
5565// output_strides(i) is derived from axes and input_strides
56- output_strides[_axes[i] ] = input_strides[i];
66+ output_strides[axis ] = input_strides[i];
5767 }
58-
59- // Output shape can be asserted once the transform
68+
69+ // Output shape can be asserted once the transform
6070// effect has been determined
6171 output_shape.update_dims ();
6272 output_tensor->resize (output_shape);
6373
6474// Perform some basic checks
65- if (input_tensor->num_elems () != output_tensor->num_elems ()){
66- uTensor_printf (" inconsistent input and output shape for reshape\n " );
67- Context::get_default_context ()->throwError (new InvalidReshapeError);
68- return ;
69- }
70- if (input_tensor->get_type () != output_tensor->get_type ()){
71- uTensor_printf (" inconsistent input and output data type for reshape\n " );
72- Context::get_default_context ()->throwError (new InvalidTensorDataTypeError);
73- return ;
75+ if (input_tensor->num_elems () != output_tensor->num_elems ()) {
76+ uTensor_printf (" inconsistent input and output shape for reshape\n " );
77+ Context::get_default_context ()->throwError (new InvalidReshapeError);
78+ return ;
79+ }
80+ if (input_tensor->get_type () != output_tensor->get_type ()) {
81+ uTensor_printf (" inconsistent input and output data type for reshape\n " );
82+ Context::get_default_context ()->throwError (
83+ new InvalidTensorDataTypeError);
84+ return ;
7485 }
75- if (!_check_input_shape ()){
76- Context::get_default_context ()->throwError (new InvalidTensorDataTypeError);
77- return ;
86+ if (!_check_input_shape ()) {
87+ Context::get_default_context ()->throwError (
88+ new InvalidTensorDataTypeError);
89+ return ;
7890 }
7991
8092// copy data
81- for (uint32_t i =0 ; i < input_tensor->num_elems (); ++i) {
82- // Index of the source value, must be calculated
83- // using the output strides and output shape
84- uint32_t idx =0 ;
85- for (uint32_t j =0 ; j < output_shape.num_dims (); j++) {
86- idx += offsets[j] * output_strides[j];
87- }
88-
89- // this is not copy: `output_tensor(i) = input_tensor(i);`
90- output_tensor (i) =static_cast <Tin>(input_tensor (idx));
93+ for (uint32_t i =0 ; i < input_tensor->num_elems (); ++i) {
94+ // Index of the source value, must be calculated
95+ // using the output strides and output shape
96+ uint32_t idx =0 ;
97+ for (uint32_t j =0 ; j < output_shape.num_dims (); j++) {
98+ idx += offsets[j] * output_strides[j];
99+ }
91100
92- // Update offsets, to iterate sequentially along strides
93- // in the order of axes
94- for (int32_t j = output_shape.num_dims () -1 ; j >=0 ; j--) {
95- offsets[j] = (offsets[j] +1 ) % (output_shape[j]);
96- if ( offsets[j] >0 ) {
97- break ;
98- }
99- }
100- }
101+ // this is not copy: `output_tensor(i) = input_tensor(i);`
102+ output_tensor (i) =static_cast <Tin>(input_tensor (idx));
101103
104+ // Update offsets, to iterate sequentially along strides
105+ // in the order of axes
106+ for (int32_t j = output_shape.num_dims () -1 ; j >=0 ; j--) {
107+ offsets[j] = (offsets[j] +1 ) % (output_shape[j]);
108+ if (offsets[j] >0 ) {
109+ break ;
110+ }
111+ }
112+ }
102113 }
103- private:
104- TensorShape _axes;
105114
106- bool _check_input_shape (){
115+ private:
116+ bool _check_input_shape () {
107117const Tensor& input_tensor = inputs[input].tensor ();
108118const TensorShape& shape = input_tensor->get_shape ();
109119uint8_t num_dims = shape.num_dims ();
110- for (int i =0 ; i < num_dims; ++i){
120+ for (int i =0 ; i < num_dims; ++i) {
111121if (shape[i] <0 ) {
112122uTensor_printf (" the output shape must be all positive\n " );
113123return false ;
@@ -117,7 +127,7 @@ class TransposeOperator : public OperatorInterface<1, 1> {
117127 }
118128};
119129
120- }
121- }
130+ }// namespace ReferenceOperators
131+ }// namespace uTensor
122132
123- #endif // UTENSOR_TRANSPOSE_H
133+ #endif // UTENSOR_TRANSPOSE_H