Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
/jaxPublic

Commite75f4de

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic] tpu.memref_squeeze now verifies squeezed dimensions even when the source is not produced via tpu.erase_memref_layout
PiperOrigin-RevId: 837889666
1 parent063fbee commite75f4de

File tree

1 file changed

+40
-42
lines changed

1 file changed

+40
-42
lines changed

‎jaxlib/mosaic/dialect/tpu/tpu_ops.cc‎

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,13 @@ LogicalResult MemRefSqueezeOp::verify() {
338338

339339
auto source_shape = source_type.getShape();
340340
auto target_shape = target_type.getShape();
341-
auto squeezed_or =
342-
computeSqueezedDimsChecked(*this, source_shape, target_shape);
343-
if (failed(squeezed_or)) {
344-
returnfailure();
341+
FAILUREOR_ASSIGN_OR_RETURN(
342+
auto squeezed,
343+
computeSqueezedDimsChecked(*this, source_shape, target_shape));
344+
if (squeezed.empty() && source_shape != target_shape) {
345+
returnemitOpError(
346+
"Source and target shapes must be the same if no dimensions are"
347+
"squeezed.");
345348
}
346349

347350
auto source_layout = source_type.getLayout();
@@ -359,7 +362,7 @@ LogicalResult MemRefSqueezeOp::verify() {
359362
}
360363
SmallVector<int64_t> target_strides;
361364
for (auto [i, stride] :llvm::enumerate(source_strides)) {
362-
if (!llvm::is_contained(*squeezed_or, i)) {
365+
if (!llvm::is_contained(squeezed, i)) {
363366
target_strides.push_back(stride);
364367
}
365368
}
@@ -369,49 +372,44 @@ LogicalResult MemRefSqueezeOp::verify() {
369372
returnemitOpError("Layout mismatch: got")
370373
<< target_layout <<", expected" << expected_layout <<".";
371374
}
372-
}
373-
374-
auto erase_layout_op =getInput().getDefiningOp<tpu::EraseLayoutOp>();
375-
if (!erase_layout_op) {
375+
returnsuccess();
376+
}elseif (!isa<TiledLayoutAttr>(source_layout) &&
377+
!isa<TiledLayoutAttr>(target_layout)) {
378+
// TODO(slebedev): Remove this branch once we migrate to TPU dialect layout
379+
// attribute on SC.
376380
returnsuccess();
377381
}
378382

379-
auto layout_ref = erase_layout_op.getOperand();
380-
MemRefType layout_ty =getMemRefType(layout_ref);
381-
auto layout_attr = dyn_cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
382-
if (!layout_attr) {
383-
returnemitOpError(
384-
"Input from EraseLayoutOp is expected to have a TiledLayoutAttr.");
385-
}
386-
auto &squeezed = squeezed_or.value();
387-
if (squeezed.empty() && source_shape != target_shape) {
388-
returnfailure();
389-
}
390-
391-
auto tiles = layout_attr.getTiles();
392-
if (tiles.size() ==1) {
393-
auto tile = layout_attr.getTiles().front();
394-
auto tile_dims = tile.dimensions();
395-
int first_tiled = source_shape.size() - tile_dims.size();
396-
for (int dim : squeezed) {
397-
if (dim >= first_tiled) {
398-
int tile_idx = dim - first_tiled;
399-
if (tile_idx <0 || tile_idx >=static_cast<int>(tile_dims.size())) {
400-
returnemitOpError() <<"Internal error: tile index out of bounds.";
401-
}
402-
if (tile_dims[tile_idx] !=1) {
403-
returnemitOpError()
404-
<<"All tiled squeezed dimensions must be of size 1.";
383+
auto tiles = cast<TiledLayoutAttr>(source_layout).getTiles();
384+
switch (tiles.size()) {
385+
case0:
386+
break;
387+
case1: {
388+
auto tile = tiles.front();
389+
auto tile_dims = tile.dimensions();
390+
int first_tiled = source_shape.size() - tile_dims.size();
391+
for (int dim : squeezed) {
392+
if (dim >= first_tiled) {
393+
int tile_idx = dim - first_tiled;
394+
if (tile_idx <0 || tile_idx >=static_cast<int>(tile_dims.size())) {
395+
returnemitOpError() <<"Internal error: tile index out of bounds.";
396+
}
397+
if (tile_dims[tile_idx] !=1) {
398+
returnemitOpError()
399+
<<"All tiled squeezed dimensions must be of size 1.";
400+
}
405401
}
406402
}
403+
break;
407404
}
408-
}else {
409-
auto first_tile = tiles.front();
410-
for (int dim : squeezed) {
411-
int first_tiled = source_shape.size() - first_tile.dimensions().size();
412-
if (dim >= first_tiled) {
413-
returnemitOpError() <<"When multiple tiles are present, no tiled"
414-
"dimensions can be squeezed.";
405+
default: {
406+
auto first_tile = tiles.front();
407+
for (int dim : squeezed) {
408+
int first_tiled = source_shape.size() - first_tile.dimensions().size();
409+
if (dim >= first_tiled) {
410+
returnemitOpError() <<"When multiple tiles are present, no tiled"
411+
"dimensions can be squeezed.";
412+
}
415413
}
416414
}
417415
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp