diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 50a7d57346a6..5cbb5e620c88 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6588,15 +6588,20 @@ FailureOr> relayout(RewriteContext &ctx, /*use_implicit_shape=*/true); } if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() && - !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) { + !src.offsets()[1].has_value()) { // A fully replicated value is always easy to relayout - // It would be nice to be able to assert this here, but given replicated - // values our rules can introduce equivalent expressions. - // assert all(t is src_tiles_list[0] for t in src_tiles_list) xla::Array dst_tiles( - /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape), - /*value=*/src_tiles.data()[0]); - return assemble_with_mask_check(dst_tiles); + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + SmallVector idxs; + dst_tiles.Each([&](const absl::Span src_idx, Value *vreg) { + idxs.assign(src_idx.begin(), src_idx.end()); + dst.eraseImplicit(idxs); + src.insertImplicit(idxs, 0); + *(idxs.end() - 2) = 0; + *(idxs.end() - 1) = 0; + *vreg = src_tiles(idxs); + }); + return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true); } // Consider (1,128),-2 -> (8,128). In this case we can change the implicit