diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index c1fba60f4cc5..ffcc8d52cd05 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -790,6 +790,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO
     Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">,
     Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">,
     Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">,
+    Option<"vmem_banks", "vmem-banks", "int", /*default=*/"-1", "">,
   ];
 }
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index 510bd384d656..00bd15b57153 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -62,6 +62,7 @@ struct ApplyVectorLayoutContext {
   // mxu_shape = {contracting_size, non_contracting_size}
   std::array mxu_shape = {128, 128};
   int64_t max_sublanes_in_scratch = 0;
+  int64_t vmem_banks = -1;  // -1 means "unspecified".
 };
 
 std::pair mightCommunicateBetweenChips(Operation* op);
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 951ed59865c7..d3e1b59afe16 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -11,6 +11,7 @@
 #include 
 #include 
 #include 
+#include 
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVectorExtras.h"
@@ -46,6 +47,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "absl/algorithm/container.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
@@ -139,18 +141,21 @@ void moveAllRegions(Operation &src, Operation &dst) {
 //
 // Returns:
 //   A memref of the requested shape and type.
-FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
-                             Location loc, ArrayRef shape,
-                             Type elem_ty) {
+FailureOr> getInternalScratch(
+    RewriteContext &ctx, OpBuilder &builder, Location loc,
+    ArrayRef shape, Type elem_ty, int64_t sublane_tiling = 0) {
   if (shape.empty()) {
     return failure();
   }
   if (shape.back() % ctx.target_shape[1] != 0) {
     return failure();
   }
-  int sublane_count =
+  int packing = 32 / elem_ty.getIntOrFloatBitWidth();
+  int sublane_count = llvm::divideCeil(
       std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) /
-      ctx.target_shape[1];
+          ctx.target_shape[1],
+      packing);
+
   if (sublane_count > ctx.max_sublanes_in_scratch) {
     return failure();
   }
@@ -159,7 +164,7 @@ FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
   FAILUREOR_ASSIGN_OR_RETURN(
       MemRefType scratch_ref_ty,
       inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation,
-                  /*tpu_tiling_flags=*/{}));
+                  /*tpu_tiling_flags=*/{}, sublane_tiling));
   return builder.create(loc, scratch_ref_ty)
       .getResult();
 }
@@ -4752,30 +4757,6 @@ xla::Array retileToReducedSublanes(
   return dst_vreg_array;
 }
 
-// Returns true iff the layout changes involve reduced sublanes per tile.
-//
-// Arguments:
-//  src: The existing layout.
-//  dst: The new layout based on which the retiling is to be carried out.
-bool isSupportedReducedSublanesRetile(
-    const VectorLayout &src, const VectorLayout &dst,
-    const std::array target_shape) {
-  return src.implicit_dim() == dst.implicit_dim() &&
-         llvm::all_of(llvm::zip_equal(src.offsets(), dst.offsets()),
-                      [](auto tup) {
-                        auto [lhs, rhs] = tup;
-                        return lhs.value_or(0) == rhs.value_or(0);
-                      })
-         // TODO (kumudbhandari): We have not tested any tile size where
-         // tile[-1] != TARGET_SHAPE.lanes. It should work but needs to be
-         // tested.
-         && src.tiling()[1] == target_shape[1] &&
-         dst.tiling()[1] == target_shape[1] &&
-         dst.tiling()[0] < src.tiling()[0] &&
-         src.bitwidth() == dst.bitwidth() &&
-         llvm::isPowerOf2_64(src.tiling()[0]) &&
-         llvm::isPowerOf2_64(dst.tiling()[0]);
-}
-
 // Copy one sublane from a vreg to another vreg.
 //
@@ -5368,13 +5349,353 @@ FailureOr>> changeOffsets(
   return std::make_pair(dst, std::move(vregs));
 }
 
-// TODO(b/265133506): Generalize retiling.
+LogicalResult retileToLargeTileWithScratch(
+    RewriteContext &ctx, OpBuilder &builder, const Location loc,
+    xla::Array &dst_tiles, const std::array &dst_tile,
+    const xla::Array &src_tiles, const std::array &src_tile,
+    TypedValue scratch_ref) {
+  if (dst_tile[0] % src_tile[0] != 0) {
+    return failure();
+  }
+  // Number of src vregs needed to assemble one dst vreg.
+  int vregs_per_group = dst_tile[0] / src_tile[0];
+  // Number of sublanes needed per src vreg to assemble one dst vreg.
+  int sl_per_vreg = ctx.target_shape[0] / vregs_per_group;
+  int stride = vregs_per_group;
+
+  xla::Array sublane_offsets(
+      {ctx.target_shape[0] / dst_tile[0], src_tile[0], vregs_per_group}, 0);
+  absl::c_iota(sublane_offsets, 0);
+  // Older hardware has limited support for shuffles, so even if we have bank
+  // conflicts, we just accept them and let the lowering unroll the
+  // loads/stores.
+  bool should_handle_bank_conflict =
+      ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 &&
+      ctx.vmem_banks < stride * ctx.target_shape[0];
+  // Add one extra sublane to the stride to avoid bank conflicts.
+  if (should_handle_bank_conflict) {
+    // Adjust the sublane offsets to match the stride.
+    for (int i = 0; i < sublane_offsets.num_elements(); i += 1) {
+      *(sublane_offsets.begin() + i) += i / stride;
+    }
+    stride += 1;
+  }
+  sublane_offsets.TransposeDimensions({0, 2, 1});
+
+  auto mlirIndexConst = [&](int d) {
+    return builder.create(
+        src_tiles.begin()->getLoc(),
+        builder.getIntegerAttr(builder.getIndexType(), d));
+  };
+  auto cst_0 = mlirIndexConst(0);
+  // Each group has exactly the number of src vregs needed to assemble one dst
+  // vreg. We cannot use a circular buffer here because we need enough space
+  // for the strided loads/stores.
+  int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group;
+  int64_t max_groups_in_scratch =
+      ctx.max_sublanes_in_scratch / sublanes_per_group;
+  if (max_groups_in_scratch < 1) {
+    return emitError(loc,
+                     "not enough scratch space for retiling to a large tile");
+  }
+  int64_t stored_group_cnt = 0;
+  auto dst_vreg_ty = src_tiles.begin()->getType();
+  // Create a new vreg type that can be stored in the scratch memref.
+  auto temp_vreg_ty =
+      VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType());
+  SmallVector sublane_mask(ctx.target_shape[0], true);
+  // (dst_vreg, load_offset)
+  std::vector> delayed_loads;
+  delayed_loads.reserve(max_groups_in_scratch * vregs_per_group);
+  // We only emit the loads when we run out of scratch space or when we are at
+  // the last vreg of the batch, to help bundle scheduling.
+  auto emit_all_delayed_loads = [&]() {
+    for (auto [dst_vreg, load_offset] : delayed_loads) {
+      Value load_op = builder.create(
+          loc, temp_vreg_ty, scratch_ref, ArrayRef({load_offset, cst_0}),
+          ArrayRef(sublane_mask),
+          ArrayRef(sublane_offsets.begin(), sublane_offsets.end()));
+      *dst_vreg = builder.create(loc, dst_vreg_ty, load_op);
+    }
+    delayed_loads.clear();
+  };
+
+  int rank = src_tiles.dimensions().size();
+  if (rank != dst_tiles.dimensions().size()) {
+    return emitError(loc, "src and dst tiles have different ranks");
+  }
+  for (int i = 0; i < rank - 2; ++i) {
+    if (src_tiles.dim(i) != dst_tiles.dim(i)) {
+      return emitError(loc,
+                       "Expected src and dst tiles to have the same dimension "
+                       "sizes on dim")
+             << i << ", but got " << src_tiles.dim(i) << " vs "
+             << dst_tiles.dim(i);
+    }
+  }
+  SmallVector src_idx(rank);
+  dst_tiles.Each([&](absl::Span dst_idx, Value *dst_vreg) {
+    int64_t dst_row_idx = *(dst_idx.end() - 2);
+    int64_t dst_col_idx = *(dst_idx.end() - 1);
+    int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
+    int64_t load_offset = sublanes_per_group * stored_group_cnt +
+                          vreg_idx_in_group * sl_per_vreg * stride;
+    delayed_loads.push_back(
+        std::make_pair(dst_vreg, mlirIndexConst(load_offset)));
+    // When the dst vreg is the last vreg of its group or the last vreg of the
+    // current dst row, we have scheduled delayed loads for every vreg of the
+    // current group, so we need to store the corresponding group of src vregs
+    // before actually emitting the loads.
+    if (vreg_idx_in_group == vregs_per_group - 1 ||
+        dst_col_idx == dst_tiles.dimensions().back() - 1) {
+      auto src_row_idx = dst_row_idx * vregs_per_group;
+      auto src_col_idx = dst_col_idx / vregs_per_group;
+      std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
+      for (int vi = 0; vi < vregs_per_group; ++vi) {
+        if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
+            src_col_idx >= src_tiles.dim(rank - 1)) {
+          break;
+        }
+        *(src_idx.end() - 2) = src_row_idx + vi;
+        *(src_idx.end() - 1) = src_col_idx;
+        Value src_vreg = src_tiles(src_idx);
+        src_vreg =
+            builder.create(loc, temp_vreg_ty, src_vreg);
+        Value store_offset =
+            mlirIndexConst(sublanes_per_group * stored_group_cnt + vi);
+        builder.create(
+            loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}),
+            ArrayRef(sublane_mask),
+            /*mask=*/nullptr, builder.getI32IntegerAttr(stride));
+      }
+      stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch;
+      // We emit the loads when we run out of scratch space or when we are at
+      // the last vreg of the batch.
+      if (stored_group_cnt == 0 ||
+          (*(dst_idx.end() - 2) == dst_tiles.dim(rank - 2) - 1 &&
+           *(dst_idx.end() - 1) == dst_tiles.dim(rank - 1) - 1)) {
+        emit_all_delayed_loads();
+      }
+    }
+  });
+  return success();
+}
+
+LogicalResult retileToSmallTileWithScratch(
+    RewriteContext &ctx, OpBuilder &builder, const Location loc,
+    xla::Array &dst_tiles, const std::array &dst_tile,
+    const xla::Array &src_tiles, const std::array &src_tile,
+    TypedValue scratch_ref) {
+  if (src_tile[0] % dst_tile[0] != 0) {
+    return failure();
+  }
+  // Number of src vregs needed to assemble one dst vreg.
+  int vregs_per_group = src_tile[0] / dst_tile[0];
+  // Number of sublanes needed per src vreg to assemble one dst vreg.
+  int sl_per_vreg = ctx.target_shape[0] / vregs_per_group;
+  int stride = vregs_per_group;
+
+  xla::Array sublane_offsets(
+      {ctx.target_shape[0] / src_tile[0], dst_tile[0], vregs_per_group}, 0);
+  absl::c_iota(sublane_offsets, 0);
+  // Older hardware has limited support for shuffles, so even if we have bank
+  // conflicts, we just accept them and let the lowering unroll the
+  // loads/stores.
+  bool should_handle_bank_conflict =
+      ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 &&
+      ctx.vmem_banks < stride * ctx.target_shape[0];
+  bool use_shuffled_load = false;
+  if (ctx.hardware_generation <= 4) {
+    if (src_tile[0] == 8) {
+      // Older hardware does not support shuffled stores. However, if the src
+      // tile is (8, 128), we can convert (shuffled store + strided load) into
+      // (strided store + shuffled load).
+      use_shuffled_load = true;
+    } else if (src_tile[0] == 4) {
+      // In this case, the trick of replacing a shuffled store with a shuffled
+      // load does not work. Handling bank conflicts would increase the sublane
+      // offsets, which might make emulation harder, so we avoid doing so.
+      should_handle_bank_conflict = false;
+    }
+  }
+
+  // Add one extra sublane to the stride to avoid bank conflicts.
+  if (should_handle_bank_conflict) {
+    // Adjust the sublane offsets to match the stride.
+    for (int i = 0; i < sublane_offsets.num_elements(); i += 1) {
+      *(sublane_offsets.begin() + i) += i / stride;
+    }
+    stride += 1;
+  }
+  sublane_offsets.TransposeDimensions({0, 2, 1});
+  auto mlirIndexConst = [&](int d) {
+    return builder.create(
+        src_tiles.begin()->getLoc(),
+        builder.getIntegerAttr(builder.getIndexType(), d));
+  };
+  auto cst_0 = mlirIndexConst(0);
+  // Each group has exactly the number of src vregs needed to assemble one dst
+  // vreg. We cannot use a circular buffer here because we need enough space
+  // for the strided loads/stores.
+  int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group;
+  int64_t max_groups_in_scratch =
+      ctx.max_sublanes_in_scratch / sublanes_per_group;
+  if (max_groups_in_scratch < 1) {
+    return emitError(loc,
+                     "not enough scratch space for retiling to a small tile");
+  }
+  int64_t stored_group_cnt = 0;
+  auto dst_vreg_ty = src_tiles.begin()->getType();
+  // Create a new vreg type that can be stored in the scratch memref.
+  auto temp_vreg_ty =
+      VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType());
+  SmallVector sublane_mask(ctx.target_shape[0], true);
+  // (dst_vreg, load_offset)
+  std::vector> delayed_loads;
+  delayed_loads.reserve(max_groups_in_scratch * vregs_per_group);
+  // We only emit the loads when we run out of scratch space or when we are at
+  // the last vreg of the batch, to help bundle scheduling.
+  auto emit_all_delayed_loads = [&]() {
+    for (auto [dst_vreg, load_offset] : delayed_loads) {
+      Value load_op;
+      if (use_shuffled_load) {
+        load_op = builder.create(
+            loc, temp_vreg_ty, scratch_ref,
+            ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask),
+            ArrayRef(sublane_offsets.begin(), sublane_offsets.end()));
+      } else {
+        load_op = builder.create(
+            loc, temp_vreg_ty, scratch_ref,
+            ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask),
+            builder.getI32IntegerAttr(stride));
+      }
+      *dst_vreg = builder.create(loc, dst_vreg_ty, load_op);
+    }
+    delayed_loads.clear();
+  };
+  int rank = src_tiles.dimensions().size();
+  if (rank != dst_tiles.dimensions().size()) {
+    return emitError(loc, "src and dst tiles have different ranks");
+  }
+  for (int i = 0; i < rank - 2; ++i) {
+    if (src_tiles.dim(i) != dst_tiles.dim(i)) {
+      return emitError(loc,
+                       "Expected src and dst tiles to have the same dimension "
+                       "sizes on dim")
+             << i << ", but got " << src_tiles.dim(i) << " vs "
+             << dst_tiles.dim(i);
+    }
+  }
+  SmallVector dst_idx(rank);
+  src_tiles.Each([&](absl::Span src_idx, Value src_vreg) {
+    int64_t src_row_idx = *(src_idx.end() - 2);
+    int64_t src_col_idx = *(src_idx.end() - 1);
+    int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
+    src_vreg = builder.create(loc, temp_vreg_ty, src_vreg);
+    if (use_shuffled_load) {
+      Value store_offset = mlirIndexConst(
+          sublanes_per_group * stored_group_cnt + vreg_idx_in_group);
+      builder.create(
+          loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}),
+          ArrayRef(sublane_mask),
+          /*mask=*/nullptr, builder.getI32IntegerAttr(stride));
+    } else {
+      Value store_offset =
+          mlirIndexConst(sublanes_per_group * stored_group_cnt +
+                         vreg_idx_in_group * sl_per_vreg * stride);
+      builder.create(
+          loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}),
+          ArrayRef(sublane_mask),
+          ArrayRef(sublane_offsets.begin(), sublane_offsets.end()));
+    }
+    // When the src vreg is the last vreg of its group or the last vreg of the
+    // current src row, we have stored all the vregs needed to assemble a new
+    // group of dst vregs.
+    if (vreg_idx_in_group == vregs_per_group - 1 ||
+        src_col_idx == src_tiles.dimensions().back() - 1) {
+      auto dst_row_idx = src_row_idx * vregs_per_group;
+      auto dst_col_idx = src_col_idx / vregs_per_group;
+      std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
+      for (int vi = 0; vi < vregs_per_group; ++vi) {
+        if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
+            dst_col_idx >= dst_tiles.dim(rank - 1)) {
+          break;
+        }
+        *(dst_idx.end() - 2) = dst_row_idx + vi;
+        *(dst_idx.end() - 1) = dst_col_idx;
+        Value *dst_vreg = &dst_tiles(dst_idx);
+        int64_t load_offset =
+            use_shuffled_load ? (sublanes_per_group * stored_group_cnt +
+                                 vi * sl_per_vreg * stride)
+                              : (sublanes_per_group * stored_group_cnt + vi);
+        delayed_loads.push_back(
+            std::make_pair(dst_vreg, mlirIndexConst(load_offset)));
+      }
+      stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch;
+      // We emit the loads when we run out of scratch space or when we are at
+      // the last vreg of the batch.
+      if (stored_group_cnt == 0 ||
+          (*(src_idx.end() - 2) == src_tiles.dim(rank - 2) - 1 &&
+           *(src_idx.end() - 1) == src_tiles.dim(rank - 1) - 1)) {
+        emit_all_delayed_loads();
+      }
+    }
+  });
+  return success();
+}
+
+// go/mosaic-retiling-in-scratch is the full internal documentation; it
+// includes more details about the TPU generations.
+LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
+                                const Location loc,
+                                xla::Array &dst_tiles,
+                                const std::array &dst_tiling,
+                                const xla::Array &src_tiles,
+                                const std::array &src_tiling,
+                                int packing) {
+  if (!(src_tiling[1] == ctx.target_shape[1] &&
+        dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
+        dst_tiling[0] % packing == 0)) {
+    return failure();
+  }
+  // Try to get i32 vector scratch space, since we bitcast vregs to i32 vregs
+  // before using the scratch for retiling. This way we can handle packed
+  // types as well.
+  auto vi32_scratch_ref = getInternalScratch(
+      ctx, builder, loc, {ctx.max_sublanes_in_scratch, ctx.target_shape[1]},
+      builder.getI32Type(), /*sublane_tiling=*/1);
+  if (failed(vi32_scratch_ref)) {
+    return emitError(loc, "Failed to get scratch ref for retiling");
+  }
+  auto ref = vi32_scratch_ref.value();
+  std::array vi32_dst_tiling = {dst_tiling[0] / packing,
+                                dst_tiling[1]};
+  std::array vi32_src_tiling = {src_tiling[0] / packing,
+                                src_tiling[1]};
+  if (src_tiling[0] > dst_tiling[0]) {
+    return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
+                                        vi32_dst_tiling, src_tiles,
+                                        vi32_src_tiling, ref);
+  }
+  if (src_tiling[0] < dst_tiling[0]) {
+    return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
+                                        vi32_dst_tiling, src_tiles,
+                                        vi32_src_tiling, ref);
+  }
+  dst_tiles = std::move(src_tiles);
+  return success();
+}
+
 FailureOr>> changeTiling(
     RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
     const VectorLayout src, xla::Array vregs,
     const std::array dst_tiling, bool try_replicate_rows) {
+  bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
+                            ctx.target_shape[0] * (ctx.target_shape[0] + 1);
   const auto &target_shape = ctx.target_shape;
-  if (src.tiling() == dst_tiling) {
+  const std::array src_tiling = src.tiling();
+  if (src_tiling == dst_tiling) {
     return std::pair(src, std::move(vregs));
   }
   const int packing = src.packing();
@@ -5384,106 +5705,62 @@ FailureOr>> changeTiling(
   if (!dst.isValid(target_shape)) {
     return emitError(loc, "Not implemented: invalid offsets in tiling target");
   }
-  // Handle retiling from (packing, 128) to (8 * packing, 128).
-  if (src.offsets() == LayoutOffsets{0, 0} &&
-      src.tiling() == std::array{packing, 128} &&
-      dst_tiling == std::array{8 * packing, 128}) {
-    bool replicate_sublanes = try_replicate_rows && packing == 1 &&
-                              *(vregs.dimensions().end() - 2) == 1;
-    xla::Array retiled(
-        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
+  auto dst_tiles_shape =
+      dst.tileArrayImplicitShape(vty.getShape(), target_shape);
+  // Handle retiling from (1, 128) to (8, 128) for 32-bit data by replicating
+  // sublanes.
+  if (try_replicate_rows && packing == 1 &&
+      *(vregs.dimensions().end() - 2) == 1 &&
+      src.offsets() == LayoutOffsets{0, 0} &&
+      src.tiling() == std::array{1, 128} &&
+      dst_tiling == std::array{8, 128}) {
+    xla::Array retiled(dst_tiles_shape);
     retiled.Each([&](absl::Span idx, Value *tile) {
       SmallVector src_idx(idx.begin(), idx.end());
       *(src_idx.end() - 2) *= target_shape[0];
       *(src_idx.end() - 1) /= target_shape[0];
       const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
-      if (replicate_sublanes) {
-        CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
-        *tile =
-            broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
-      } else {
-        for (int dst_sl_idx = 0;
-             dst_sl_idx < target_shape[0] &&
-             *(src_idx.end() - 2) < *(vregs.dimensions().end() - 2);
-             ++dst_sl_idx, ++*(src_idx.end() - 2)) {
-          *tile = copy_one_sublane(builder, vregs(src_idx), src_sl_idx, *tile,
-                                   dst_sl_idx, target_shape);
-        }
-      }
+      CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
+      *tile =
+          broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
     });
     // We have successfully replicated sublanes.
-    if (replicate_sublanes) {
-      dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
-                         dst.implicit_dim());
-    }
-    return std::pair(dst, std::move(retiled));
-  }
-  // Handle retiling from (m, 128) to (8, 128) for 32-bit data
-  // where m < 8 and m is a power of 2.
-  // TODO(b/306692696): Handle any vregs.dimensions().
-  if (bitwidth == 32 && src.offsets() == LayoutOffsets{0, 0} &&
-      target_shape[0] % src.tiling()[0] == 0 &&
-      src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape &&
-      *(vregs.dimensions().end() - 2) == 1) {
-    xla::Array retiled(
-        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
-    retiled.Each([&](const absl::Span idx,
-                     Value *const new_src_tile) {
-      const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);
-      const int64_t dst_col = idx.back();
-      const int64_t src_col = dst_col / tiles_per_vreg;
-      const int64_t start_slane_idx =
-          src.tiling()[0] * (dst_col % tiles_per_vreg);
-      SmallVector src_idx(toArrayRef(idx));
-      src_idx.back() = src_col;
-      Value src_tile = vregs(src_idx);
-      if (start_slane_idx) {
-        SmallVector slane_idxs;
-        slane_idxs.reserve(target_shape[0]);
-        for (int i = 0; i < target_shape[0]; ++i) {
-          slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0]));
-        }
-        const DenseI32ArrayAttr gather_indices =
-            builder.getDenseI32ArrayAttr(slane_idxs);
-        *new_src_tile = builder.create(loc, src_tile.getType(),
-                                       src_tile, gather_indices,
-                                       /*dimension=*/0);
-      } else {
-        *new_src_tile = src_tile;
-      }
-    });
+    dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
+                       dst.implicit_dim());
     return std::pair(dst, std::move(retiled));
   }
   // (8,128) -> (8 * packing,128) tiling change for packed type.
   if (bitwidth < 32 && 32 % bitwidth == 0 &&
-      src.tiling() == std::array{8, 128} &&
-      dst.tiling() == std::array{8 * dst.packing(), 128}) {
-    xla::Array retiled(
-        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
-    int vty_packing = dst.packing();
-    VectorType vreg_x32 =
-        vty.getElementType().isSignlessInteger()
-            ? VectorType::get(target_shape, builder.getI32Type())
-            : VectorType::get(target_shape, builder.getF32Type());
-    retiled.Each([&](absl::Span idx, Value *tile) {
-      const int vreg_part = idx.back() % vty_packing;
-      SmallVector parts;
-      parts.reserve(vty_packing);
-      SmallVector src_idx(idx.begin(), idx.end());
-      src_idx[src_idx.size() - 2] *= vty_packing;
-      src_idx[src_idx.size() - 1] /= vty_packing;
-      for (int i = 0; i < vty_packing; ++i) {
-        parts.push_back(builder.create(
-            loc, vreg_x32, vregs(src_idx), vreg_part));
-        if (src_idx[src_idx.size() - 2] <
-            vregs.dim(vregs.num_dimensions() - 2) - 1) {
-          ++src_idx[src_idx.size() - 2];
+      src_tiling == std::array{8, 128} &&
+      dst_tiling == std::array{8 * dst.packing(), 128}) {
+    // Note: for int4, retiling with scratch is always faster.
+    if (bitwidth != 4 || !has_enough_scratch) {
+      xla::Array retiled(dst_tiles_shape);
+      int vty_packing = dst.packing();
+      VectorType vreg_x32 =
+          vty.getElementType().isSignlessInteger()
+              ? VectorType::get(target_shape, builder.getI32Type())
+              : VectorType::get(target_shape, builder.getF32Type());
+      retiled.Each([&](absl::Span idx, Value *tile) {
+        const int vreg_part = idx.back() % vty_packing;
+        SmallVector parts;
+        parts.reserve(vty_packing);
+        SmallVector src_idx(idx.begin(), idx.end());
+        src_idx[src_idx.size() - 2] *= vty_packing;
+        src_idx[src_idx.size() - 1] /= vty_packing;
+        for (int i = 0; i < vty_packing; ++i) {
+          parts.push_back(builder.create(
+              loc, vreg_x32, vregs(src_idx), vreg_part));
+          if (src_idx[src_idx.size() - 2] <
+              vregs.dim(vregs.num_dimensions() - 2) - 1) {
+            ++src_idx[src_idx.size() - 2];
+          }
        }
-      }
-      *tile = builder.create(
-          loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
-    });
-    return std::pair(dst, std::move(retiled));
+        *tile = builder.create(
+            loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
+      });
+      return std::pair(dst, std::move(retiled));
+    }
   }
   // Handle retiling from (1, 128 * packing) to (packing, 128) for
   // packed data.
@@ -5497,8 +5774,8 @@ FailureOr>> changeTiling(
   // match corresponding elements without shifting. It's just that
   // the tiles are not adjacent (no contiguous vreg slice).
   if (bitwidth < 32 && 32 % bitwidth == 0 &&
-      src.tiling() == std::array{1, 128 * packing} &&
-      dst.tiling() == std::array{packing, 128}) {
+      src_tiling == std::array{1, 128 * packing} &&
+      dst_tiling == std::array{packing, 128}) {
     // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
     // 4 sublanes and 2 lanes (this is convenient for to keep the example small
     // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
@@ -5539,8 +5816,7 @@ FailureOr>> changeTiling(
     // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
     // moving to the next one. This is exactly an interleaving of the sublanes
    // of the vreg parts.
-    xla::Array retiled(
-        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
+    xla::Array retiled(dst_tiles_shape);
     const VectorType vreg_x32 =
         vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
            : VectorType::get(target_shape, builder.getF32Type());
@@ -5565,13 +5841,41 @@ FailureOr>> changeTiling(
     });
     return std::pair(dst, std::move(retiled));
   }
-  if (isSupportedReducedSublanesRetile(src, dst, target_shape)) {
-    return std::pair(dst, retileToReducedSublanes(builder, vty.getShape(), src,
-                                                  vregs, dst, target_shape));
+  if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
+    // TODO(b/368088671): When the sublane tiling changes, we should be able to
+    // preserve some replications from the source layout. But we need to make
+    // sure they are implemented efficiently and well-tested. For now, we just
+    // use 0 for the replicated offset after retiling.
+    dst = VectorLayout(
+        bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
+        dst_tiling, dst.implicit_dim());
+
+    // All clauses in the condition below are based on performance benchmarking.
+    bool use_alu = !has_enough_scratch ||
+                   (ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
+                    dst_tiling[0] != packing);
+
+    if (use_alu) {
+      if (src_tiling[0] > dst_tiling[0]) {
+        return std::pair(
+            dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
                                          dst, target_shape));
+      } else if (!has_enough_scratch) {
+        // TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
+        return emitError(
+            loc,
+            "Not implemented: retiling to increase sublane tiling with ALU");
+      }
+    }
+    xla::Array retiled(dst_tiles_shape);
+    if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
+                                 src_tiling, packing))) {
+      return failure();
+    }
+    return std::pair(dst, std::move(retiled));
   }
   return emitError(loc, "Not implemented: Unsupported tiling change for ")
-         << vty << ": from " << src << " to tiling (" << dst_tiling[0] << ", "
-         << dst_tiling[1] << ")";
+         << vty << ": from " << src << " to " << dst;
 }
 
 FailureOr>> changeImplicitDim(
@@ -5878,6 +6182,7 @@ struct ApplyVectorLayoutPass
     mxu_contracting_size = ctx.mxu_shape[0];
     mxu_noncontracting_size = ctx.mxu_shape[1];
     max_sublanes_in_scratch = ctx.max_sublanes_in_scratch;
+    vmem_banks = ctx.vmem_banks;
   }
   void runOnOperation() override {
     // Fail if hardware_generation has not been set from the default value.
@@ -5889,7 +6194,8 @@ struct ApplyVectorLayoutPass
         .hardware_generation = hardware_generation,
         .target_shape = {sublane_count, lane_count},
         .mxu_shape = {mxu_contracting_size, mxu_noncontracting_size},
-        .max_sublanes_in_scratch = max_sublanes_in_scratch};
+        .max_sublanes_in_scratch = max_sublanes_in_scratch,
+        .vmem_banks = vmem_banks};
     if (failed(applyLayoutFunc(ctx, getOperation()))) {
       signalPassFailure();
       return;
diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py
index 84403e41b561..87ccaa644e8c 100644
--- a/tests/pallas/tpu_pallas_test.py
+++ b/tests/pallas/tpu_pallas_test.py
@@ -2464,9 +2464,7 @@ def kernel(x_ref, out_ref):
     )(x)
     np.testing.assert_array_equal(out, np.broadcast_to(x, (256, 512)))
 
-  @only_passes_in_interpret(unless_generation=4)
   def test_bfloat16_to_uint32_bitcast(self):
-    """b/347771903"""
     x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256)
 
     def kernel(x_ref, out_ref):
@@ -2475,7 +2473,7 @@ def kernel(x_ref, out_ref):
     out = self.pallas_call(
         kernel, out_shape=jax.ShapeDtypeStruct((16, 1, 256), jnp.uint32)
     )(x)
-    # FIXME: Add correctness test for result.
+    np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32))
 
   @only_passes_in_interpret()
   def test_roll_partial(self):
@@ -2548,9 +2546,7 @@ def kernel(x_ref, out_ref):
 
     np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128)))
 
-  @only_passes_in_interpret()
   def test_mixed_strides(self):
-    """b/352841329"""
     x = np.zeros((8, 128), dtype=jnp.float32)
     y = np.zeros((8, 2, 128), dtype=jnp.bfloat16)
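
The sketch below is not part of the patch; it is a minimal NumPy re-derivation of how `retileToLargeTileWithScratch` above builds its `sublane_offsets` array and store `stride` (iota, optional bank-conflict padding, then a `{0, 2, 1}` transpose). The helper name and parameters are invented for illustration, and the bank-conflict branch mirrors `should_handle_bank_conflict` minus the hardware-generation check.

```python
import numpy as np

def large_tile_sublane_offsets(target_sublanes, src_rows, dst_rows, vmem_banks):
  """target_sublanes ~ ctx.target_shape[0]; src_rows/dst_rows are the
  i32-normalized src_tile[0] / dst_tile[0] from the patch. Hypothetical name."""
  vregs_per_group = dst_rows // src_rows
  stride = vregs_per_group
  shape = (target_sublanes // dst_rows, src_rows, vregs_per_group)
  offsets = np.arange(np.prod(shape), dtype=np.int32).reshape(shape)
  # Mirror of should_handle_bank_conflict: spread offsets and widen the stride
  # by one sublane when the strided accesses would collide in VMEM banks.
  if 0 < vmem_banks < stride * target_sublanes:
    offsets += np.arange(offsets.size, dtype=np.int32).reshape(shape) // stride
    stride += 1
  # xla::Array::TransposeDimensions({0, 2, 1}), then flatten row-major.
  return offsets.transpose(0, 2, 1).reshape(-1), stride

# Example: int8 (8, 128) -> (32, 128) retiling, i.e. (2, 128) -> (8, 128) in
# i32 units on an 8-sublane target: each dst vreg gathers 2 sublanes from each
# of the 4 src vregs stored with the returned stride.
offsets, stride = large_tile_sublane_offsets(8, 2, 8, vmem_banks=0)
print(offsets, stride)  # [0 4 1 5 2 6 3 7] 4
```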