#ifndef UTENSOR_TRANSPOSE_H
#define UTENSOR_TRANSPOSE_H

#include <cstring>

#include "context.hpp"
#include "operatorBase.hpp"
#include "tensor.hpp"
#include "types.hpp"
#include "uTensor_util.hpp"

namespace uTensor {
namespace ReferenceOperators {

// Transpose (swap axes), ported from the NumPy implementation,
// using stride iteration in the order of the transpose axes.
template <typename Tin>
class TransposeOperator : public OperatorInterface<2, 1> {
  /* Copies the input into the output with its axes permuted: input dimension
   * i is mapped to output dimension perm[i]. */
 public:
  enum names_in : uint8_t { input, perm };
  enum names_out : uint8_t { output };

  virtual void compute() {
    const Tensor& perm_tensor = inputs[perm].tensor();
    if (perm_tensor.get_shape().num_dims() > 1) {
      uTensor_printf(
          "the input tensor perm should be a vector (dimension should be 1)\n");
      Context::get_default_context()->throwError(new InvalidTensorInputError);
    }
    if (perm_tensor->get_type() != i32) {
      uTensor_printf("expecting perm tensor of element type int32_t\n");
      Context::get_default_context()->throwError(
          new InvalidTensorDataTypeError);
    }
    Tensor& input_tensor = inputs[input].tensor();
    TensorShape& input_shape = input_tensor.get_shape();
    input_shape.update_dims();

    // Strides are used to iterate over the data and to transfer the input
    // tensor's elements into the output tensor.
    TensorStrides input_strides = TensorStrides(input_shape);
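    // For reference (illustrative values, assuming the row-major layout this
    // operator relies on): a {2, 3, 4} input has strides {12, 4, 1}, i.e. the
    // flat offset of element (i, j, k) is 12*i + 4*j + 1*k.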

    Tensor& output_tensor = outputs[output].tensor();

    // Placeholders used to compute the output shape and strides. Normally we
    // would write into the output tensor's shape directly, but the output
    // could (and usually would) reference the input, so keep dedicated local
    // values instead.
    TensorShape output_shape = TensorShape(1, 1, 1, 1);
    TensorStrides output_strides = TensorStrides(output_shape);
    TensorShape offsets = TensorShape(input_shape.num_dims());

    for (size_t i = 0; i < 4; ++i) {
      output_shape[i] = 0;
      output_strides[i] = 0;

      // offsets acts as a multi-dimensional counter over the output, so a
      // single flat loop can stand in for one nested loop per dimension.
      offsets[i] = 0;
    }

    for (size_t i = 0; i < (size_t)input_shape.num_dims(); ++i) {
      int32_t axis = static_cast<int32_t>(perm_tensor(i));
      output_shape[axis] = input_shape[i];

      // The output stride along `axis` is the input stride along dimension i,
      // so walking the output in row-major order gathers the permuted input.
      output_strides[axis] = input_strides[i];
    }
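
    // Worked example (values illustrative): for a 2x3 input with
    // perm = {1, 0}, input_strides = {3, 1}, so output_shape becomes {3, 2}
    // and output_strides becomes {1, 3}; iterating the output row-major then
    // reads the input column by column.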

    // The output shape is only known once the permutation has been applied,
    // so it can be finalized (and the output tensor resized) here.
    output_shape.update_dims();
    output_tensor->resize(output_shape);

    // Perform some basic checks
    if (input_tensor->num_elems() != output_tensor->num_elems()) {
      uTensor_printf("inconsistent input and output shape for transpose\n");
      Context::get_default_context()->throwError(new InvalidReshapeError);
      return;
    }
    if (input_tensor->get_type() != output_tensor->get_type()) {
      uTensor_printf("inconsistent input and output data type for transpose\n");
      Context::get_default_context()->throwError(
          new InvalidTensorDataTypeError);
      return;
    }
    if (!_check_input_shape()) {
      Context::get_default_context()->throwError(
          new InvalidTensorDataTypeError);
      return;
    }

    // copy data
    for (uint32_t i = 0; i < input_tensor->num_elems(); ++i) {
      // The index of the source element must be computed from the current
      // output multi-index (offsets) and the permuted strides.
      uint32_t idx = 0;
      for (uint32_t j = 0; j < output_shape.num_dims(); j++) {
        idx += offsets[j] * output_strides[j];
      }

      // Note: a plain element-wise copy, `output_tensor(i) = input_tensor(i)`,
      // would not transpose; the source index has to be remapped through idx.
      output_tensor(i) = static_cast<Tin>(input_tensor(idx));

      // Advance offsets like an odometer: increment the innermost output
      // dimension first and carry into the next dimension on wrap-around.
      for (int32_t j = output_shape.num_dims() - 1; j >= 0; j--) {
        offsets[j] = (offsets[j] + 1) % (output_shape[j]);
        if (offsets[j] > 0) {
          break;
        }
      }
    }
  }

 private:
  bool _check_input_shape() {
    const Tensor& input_tensor = inputs[input].tensor();
    const TensorShape& shape = input_tensor->get_shape();
    uint8_t num_dims = shape.num_dims();
    for (int i = 0; i < num_dims; ++i) {
      if (shape[i] < 0) {
        uTensor_printf("the input shape must be all positive\n");
        return false;
      }
    }
    return true;
  }
};
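
// Usage sketch (illustrative only; the RomTensor/RamTensor construction and
// the shapes/values below are assumptions following the usual uTensor
// ReferenceOperators wiring, not something defined by this header):
//
//   static const int32_t perm_data[2] = {1, 0};
//   Tensor in = new RomTensor({2, 3}, i8, in_data);    // 2x3 input
//   Tensor perm = new RomTensor({2}, i32, perm_data);  // swap the two axes
//   Tensor out = new RamTensor({3, 2}, i8);            // 3x2 output
//   TransposeOperator<int8_t> op;
//   op.set_inputs({{TransposeOperator<int8_t>::input, in},
//                  {TransposeOperator<int8_t>::perm, perm}})
//       .set_outputs({{TransposeOperator<int8_t>::output, out}})
//       .eval();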

}  // namespace ReferenceOperators
}  // namespace uTensor

#endif  // UTENSOR_TRANSPOSE_H