@@ -338,10 +338,13 @@ LogicalResult MemRefSqueezeOp::verify() {
338338
339339auto source_shape = source_type.getShape ();
340340auto target_shape = target_type.getShape ();
341- auto squeezed_or =
342- computeSqueezedDimsChecked (*this , source_shape, target_shape);
343- if (failed (squeezed_or)) {
344- return failure ();
341+ FAILUREOR_ASSIGN_OR_RETURN (
342+ auto squeezed,
343+ computeSqueezedDimsChecked (*this , source_shape, target_shape));
344+ if (squeezed.empty () && source_shape != target_shape) {
345+ return emitOpError (
346+ " Source and target shapes must be the same if no dimensions are"
347+ " squeezed." );
345348 }
346349
347350auto source_layout = source_type.getLayout ();
@@ -359,7 +362,7 @@ LogicalResult MemRefSqueezeOp::verify() {
359362 }
360363 SmallVector<int64_t > target_strides;
361364for (auto [i, stride] :llvm::enumerate (source_strides)) {
362- if (!llvm::is_contained (*squeezed_or , i)) {
365+ if (!llvm::is_contained (squeezed , i)) {
363366 target_strides.push_back (stride);
364367 }
365368 }
@@ -369,49 +372,44 @@ LogicalResult MemRefSqueezeOp::verify() {
369372return emitOpError (" Layout mismatch: got" )
370373 << target_layout <<" , expected" << expected_layout <<" ." ;
371374 }
372- }
373-
374- auto erase_layout_op =getInput ().getDefiningOp <tpu::EraseLayoutOp>();
375- if (!erase_layout_op) {
375+ return success ();
376+ }else if (!isa<TiledLayoutAttr>(source_layout) &&
377+ !isa<TiledLayoutAttr>(target_layout)) {
378+ // TODO(slebedev): Remove this branch once we migrate to TPU dialect layout
379+ // attribute on SC.
376380return success ();
377381 }
378382
379- auto layout_ref = erase_layout_op.getOperand ();
380- MemRefType layout_ty =getMemRefType (layout_ref);
381- auto layout_attr = dyn_cast<tpu::TiledLayoutAttr>(layout_ty.getLayout ());
382- if (!layout_attr) {
383- return emitOpError (
384- " Input from EraseLayoutOp is expected to have a TiledLayoutAttr." );
385- }
386- auto &squeezed = squeezed_or.value ();
387- if (squeezed.empty () && source_shape != target_shape) {
388- return failure ();
389- }
390-
391- auto tiles = layout_attr.getTiles ();
392- if (tiles.size () ==1 ) {
393- auto tile = layout_attr.getTiles ().front ();
394- auto tile_dims = tile.dimensions ();
395- int first_tiled = source_shape.size () - tile_dims.size ();
396- for (int dim : squeezed) {
397- if (dim >= first_tiled) {
398- int tile_idx = dim - first_tiled;
399- if (tile_idx <0 || tile_idx >=static_cast <int >(tile_dims.size ())) {
400- return emitOpError () <<" Internal error: tile index out of bounds." ;
401- }
402- if (tile_dims[tile_idx] !=1 ) {
403- return emitOpError ()
404- <<" All tiled squeezed dimensions must be of size 1." ;
383+ auto tiles = cast<TiledLayoutAttr>(source_layout).getTiles ();
384+ switch (tiles.size ()) {
385+ case 0 :
386+ break ;
387+ case 1 : {
388+ auto tile = tiles.front ();
389+ auto tile_dims = tile.dimensions ();
390+ int first_tiled = source_shape.size () - tile_dims.size ();
391+ for (int dim : squeezed) {
392+ if (dim >= first_tiled) {
393+ int tile_idx = dim - first_tiled;
394+ if (tile_idx <0 || tile_idx >=static_cast <int >(tile_dims.size ())) {
395+ return emitOpError () <<" Internal error: tile index out of bounds." ;
396+ }
397+ if (tile_dims[tile_idx] !=1 ) {
398+ return emitOpError ()
399+ <<" All tiled squeezed dimensions must be of size 1." ;
400+ }
405401 }
406402 }
403+ break ;
407404 }
408- }else {
409- auto first_tile = tiles.front ();
410- for (int dim : squeezed) {
411- int first_tiled = source_shape.size () - first_tile.dimensions ().size ();
412- if (dim >= first_tiled) {
413- return emitOpError () <<" When multiple tiles are present, no tiled"
414- " dimensions can be squeezed." ;
405+ default : {
406+ auto first_tile = tiles.front ();
407+ for (int dim : squeezed) {
408+ int first_tiled = source_shape.size () - first_tile.dimensions ().size ();
409+ if (dim >= first_tiled) {
410+ return emitOpError () <<" When multiple tiles are present, no tiled"
411+ " dimensions can be squeezed." ;
412+ }
415413 }
416414 }
417415 }