From ed4958cb3eddde33dbfc3230d9594e7163410658 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 17 Jun 2024 17:28:49 -0700 Subject: [PATCH] [XLA:Mosaic] Add internal scratch VMEM - Make internal scratch size configurable. - Pass the number of max sublanes allowed in scratch to apply-vector-layout pass. - Create a helper function to fetch internal scratch VMEM address. PiperOrigin-RevId: 644184896 --- jaxlib/mosaic/dialect/tpu/tpu.td | 7 +++ jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 3 +- .../tpu/transforms/apply_vector_layout.cc | 43 +++++++++++++++++-- .../tpu/transforms/apply_vector_layout.h | 1 + 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 458ce9fc9ca2..3a9633fe4c4b 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -605,6 +605,12 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; } +def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { + let arguments = (ins); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { let arguments = (ins Variadic:$seeds); let results = (outs); @@ -695,6 +701,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">, Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, + Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 038f32bca42a..dc5b68246e3f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -58,7 +58,8 @@ std::unique_ptr> createInferVectorLayoutPass( std::unique_ptr> createApplyVectorLayoutPass( int hardware_generation = -1, int lane_count = 128, int sublane_count = 8, - int mxu_contracting_size = 128, int mxu_noncontracting_size = 128); + int mxu_contracting_size = 128, int mxu_noncontracting_size = 128, + int max_sublanes_in_scratch = 0); std::unique_ptr> createLogicalToPhysicalDeviceIdPass(int64_t total_devices); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 22ca5143f777..a05448462197 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -157,6 +158,36 @@ FailureOr maskOOB(RewriteContext &ctx, OpBuilder &builder, .getResult(); } +// Get the address of pre-allocated internal scratch space with requested shape. +// +// Arguments: +// shape: The shape of the requested scratch space. +// elem_ty: The type of the elements in the requested scratch space. +// +// Returns: +// A memref of the requested shape and type. +FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, + Location loc, ArrayRef shape, + Type elem_ty) { + if (shape.empty()) { + return failure(); + } + if (shape.back() % ctx.target_shape[1] != 0) { + return failure(); + } + int sublane_count = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) / + ctx.target_shape[1]; + if (sublane_count > ctx.max_sublanes_in_scratch) { + return failure(); + } + FAILUREOR_ASSIGN_OR_RETURN( + MemRefType scratch_ref_ty, + inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation)); + return builder.create(loc, scratch_ref_ty) + .getResult(); +} + // Models Numpy's np.repeat, repeating each element `repeats` times along the // specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is // 3, this will return [1, 1, 1, 2, 2, 2]. @@ -5024,12 +5055,14 @@ struct ApplyVectorLayoutPass : public impl::ApplyVectorLayoutPassBase { ApplyVectorLayoutPass(int hardware_generation_, int lane_count_, int sublane_count_, int mxu_contracting_size_, - int mxu_noncontracting_size_) { + int mxu_noncontracting_size_, + int max_sublanes_in_scratch_) { hardware_generation = hardware_generation_; sublane_count = sublane_count_; lane_count = lane_count_; mxu_contracting_size = mxu_contracting_size_; mxu_noncontracting_size = mxu_noncontracting_size_; + max_sublanes_in_scratch = max_sublanes_in_scratch_; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. @@ -5041,7 +5074,8 @@ struct ApplyVectorLayoutPass RewriteContext ctx{func, hardware_generation, {sublane_count, lane_count}, - {mxu_contracting_size, mxu_noncontracting_size}}; + {mxu_contracting_size, mxu_noncontracting_size}, + max_sublanes_in_scratch}; if (failed(applyLayoutFunc(ctx, func))) { signalPassFailure(); return; @@ -5051,10 +5085,11 @@ struct ApplyVectorLayoutPass std::unique_ptr> createApplyVectorLayoutPass( int hardware_generation, int lane_count, int sublane_count, - int mxu_contracting_size, int mxu_noncontracting_size) { + int mxu_contracting_size, int mxu_noncontracting_size, + int max_sublanes_in_scratch) { return std::make_unique( hardware_generation, lane_count, sublane_count, mxu_contracting_size, - mxu_noncontracting_size); + mxu_noncontracting_size, max_sublanes_in_scratch); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index 75fb5e7904a1..547a8a00c10c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -21,6 +21,7 @@ struct RewriteContext { const int hardware_generation; const std::array target_shape = {8, 128}; const std::array mxu_shape = {128, 128}; + const int max_sublanes_in_scratch = 0; MLIRContext *getMLIRContext() { return func.getContext(); } };