@@ -124,85 +124,4 @@ REGISTER_OP("TPUPartitionedInput")
124124return OkStatus ();
125125 });
126126
127- REGISTER_OP (" TPUPartitionedInputV2" )
128- .Input(" inputs: N * T" )
129- .Output(" output: T" )
130- .Attr(" N: int >= 1" )
131- .Attr(" T: type" )
132- .Attr(" partition_dims: list(int)" )
133- .Attr(" is_packed: bool = false" )
134- .SetShapeFn([](InferenceContext* c) {
135- DataType dtype;
136- TF_RETURN_IF_ERROR (c->GetAttr (" T" , &dtype));
137- std::vector<int > partition_dims;
138- TF_RETURN_IF_ERROR (c->GetAttr (" partition_dims" , &partition_dims));
139- bool is_packed;
140- TF_RETURN_IF_ERROR (c->GetAttr (" is_packed" , &is_packed));
141-
142- int num_partitions =1 ;
143- for (const int & partition_dim : partition_dims) {
144- num_partitions *= partition_dim;
145- }
146-
147- bool replicated = partition_dims.empty ();
148- int num_inputs_expected = is_packed ?1 : num_partitions;
149- if (!((replicated && !is_packed) ||
150- (c->num_inputs () == num_inputs_expected))) {
151- // we cannot validate the number of inputs for replicated, unpacked ops
152- // since we cannot infer the number of partitions from partition_dims
153- return errors::InvalidArgument (" Expected" , num_inputs_expected,
154- " inputs, got" , c->num_inputs ()," ." );
155- }else if (c->num_inputs () ==0 ) {
156- return errors::InvalidArgument (
157- " Expected at least one input to TPUPartitionedInputV2." );
158- }
159-
160- ShapeHandle output_shape;
161- if (dtype == DT_RESOURCE) {
162- ShapeHandle previous_shape_handle;
163- const std::vector<shape_inference::ShapeAndType>* shapes_and_types =
164- nullptr ;
165- for (int i = c->num_inputs () -1 ; i >=0 ; --i) {
166- shapes_and_types = c->input_handle_shapes_and_types (i);
167- if (shapes_and_types) {
168- ShapeHandle shape_handle = shapes_and_types->at (0 ).shape ;
169- if (!c->FullyDefined (shape_handle)) {
170- return errors::InvalidArgument (" Inputs must have static shape," ,
171- " input[" , i,
172- " ] has unknown dimension." );
173- }
174-
175- if (i != c->num_inputs () -1 ) {
176- ShapeHandle tmp;
177- if (!c->Merge (shape_handle, previous_shape_handle, &tmp).ok ()) {
178- return errors::InvalidArgument (
179- " Inputs must have the same shape." );
180- }
181- }else {
182- previous_shape_handle = shape_handle;
183- }
184- }
185- }
186-
187- if (shapes_and_types) {
188- TF_ASSIGN_OR_RETURN (
189- output_shape,
190- _ComputeOutputShape (c, previous_shape_handle, partition_dims));
191- std::vector<shape_inference::ShapeAndType> output_shapes_and_types;
192- output_shapes_and_types.push_back (shape_inference::ShapeAndType (
193- output_shape, shapes_and_types->at (0 ).dtype ));
194- c->set_output_handle_shapes_and_types (0 , output_shapes_and_types);
195- }
196- }
197-
198- if (!c->FullyDefined (output_shape)) {
199- TF_ASSIGN_OR_RETURN (
200- output_shape,_ComputeOutputShape (c, c->input (0 ), partition_dims));
201- }
202-
203- c->set_output (0 , output_shape);
204-
205- return OkStatus ();
206- });
207-
208127}// namespace tensorflow