Skip to content

Commit

Permalink
[XLA:Mosaic] Add internal scratch VMEM
Browse files Browse the repository at this point in the history
- Make internal scratch size configurable.
- Pass the maximum number of sublanes allowed in scratch to the apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.

PiperOrigin-RevId: 644184896
  • Loading branch information
bythew3i authored and jax authors committed Jun 18, 2024
1 parent 701c63e commit ed4958c
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
7 changes: 7 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,12 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> {
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
}

// `tpu.internal_scratch`: an op with no operands that produces a single memref
// result. NOTE(review): presumably the result refers to the pre-allocated
// internal scratch buffer used by the apply-vector-layout pass — confirm
// against the lowering of this op.
def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> {
// No operands: everything about the result is carried by its type.
let arguments = (ins);
// Any memref type is accepted for the result; no further verification is
// declared in this record.
let results = (outs AnyMemRef:$result);
// Textual form: optional attribute dictionary, then `:` and the result type.
let assemblyFormat = [{ attr-dict `:` type($result) }];
}

def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> {
let arguments = (ins Variadic<I32>:$seeds);
let results = (outs);
Expand Down Expand Up @@ -695,6 +701,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">,
Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">,
Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">,
];
}

Expand Down
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation = -1, int lane_count = 128, int sublane_count = 8,
int mxu_contracting_size = 128, int mxu_noncontracting_size = 128);
int mxu_contracting_size = 128, int mxu_noncontracting_size = 128,
int max_sublanes_in_scratch = 0);

std::unique_ptr<OperationPass<func::FuncOp>>
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
Expand Down
43 changes: 39 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -157,6 +158,36 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
.getResult();
}

// Get the address of pre-allocated internal scratch space with requested shape.
//
// The request is rejected (failure is returned) when:
//   - `shape` is empty,
//   - any dimension is negative (e.g. a dynamic-dimension sentinel), since the
//     footprint of the request cannot be computed,
//   - the minormost dimension is not a multiple of the lane count
//     (`ctx.target_shape[1]`),
//   - the number of sublanes needed (total element count divided by the lane
//     count) exceeds `ctx.max_sublanes_in_scratch`,
//   - `inferMemref` fails for the requested memref type.
//
// Arguments:
//   ctx: The rewrite context; supplies the target shape, the hardware
//     generation and the scratch size limit.
//   builder: The builder used to create the `tpu::GetInternalScratchOp`.
//   loc: The location attached to the created op.
//   shape: The shape of the requested scratch space.
//   elem_ty: The type of the elements in the requested scratch space.
//
// Returns:
//   A memref of the requested shape and type.
FailureOr<Value> getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
                                    Location loc, ArrayRef<int64_t> shape,
                                    Type elem_ty) {
  if (shape.empty()) {
    return failure();
  }
  // Accumulate in 64 bits: the dimensions are int64_t, and multiplying them as
  // `int` would truncate each one and could overflow, letting an oversized
  // request slip past the size check below.
  int64_t num_elements = 1;
  for (const int64_t dim : shape) {
    if (dim < 0) {
      // A negative extent would make the sublane count negative, which would
      // always pass the limit check.
      return failure();
    }
    num_elements *= dim;
  }
  const int64_t lane_count = ctx.target_shape[1];
  if (shape.back() % lane_count != 0) {
    return failure();
  }
  const int64_t sublane_count = num_elements / lane_count;
  if (sublane_count > ctx.max_sublanes_in_scratch) {
    return failure();
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      MemRefType scratch_ref_ty,
      inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation));
  return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
      .getResult();
}

// Models Numpy's np.repeat, repeating each element `repeats` times along the
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
// 3, this will return [1, 1, 1, 2, 2, 2].
Expand Down Expand Up @@ -5024,12 +5055,14 @@ struct ApplyVectorLayoutPass
: public impl::ApplyVectorLayoutPassBase<ApplyVectorLayoutPass> {
ApplyVectorLayoutPass(int hardware_generation_, int lane_count_,
int sublane_count_, int mxu_contracting_size_,
int mxu_noncontracting_size_) {
int mxu_noncontracting_size_,
int max_sublanes_in_scratch_) {
hardware_generation = hardware_generation_;
sublane_count = sublane_count_;
lane_count = lane_count_;
mxu_contracting_size = mxu_contracting_size_;
mxu_noncontracting_size = mxu_noncontracting_size_;
max_sublanes_in_scratch = max_sublanes_in_scratch_;
}
void runOnOperation() override {
// Fail if hardware_generation has not been set from the default value.
Expand All @@ -5041,7 +5074,8 @@ struct ApplyVectorLayoutPass
RewriteContext ctx{func,
hardware_generation,
{sublane_count, lane_count},
{mxu_contracting_size, mxu_noncontracting_size}};
{mxu_contracting_size, mxu_noncontracting_size},
max_sublanes_in_scratch};
if (failed(applyLayoutFunc(ctx, func))) {
signalPassFailure();
return;
Expand All @@ -5051,10 +5085,11 @@ struct ApplyVectorLayoutPass

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation, int lane_count, int sublane_count,
int mxu_contracting_size, int mxu_noncontracting_size) {
int mxu_contracting_size, int mxu_noncontracting_size,
int max_sublanes_in_scratch) {
return std::make_unique<ApplyVectorLayoutPass>(
hardware_generation, lane_count, sublane_count, mxu_contracting_size,
mxu_noncontracting_size);
mxu_noncontracting_size, max_sublanes_in_scratch);
}

} // namespace mlir::tpu
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct RewriteContext {
const int hardware_generation;
const std::array<int64_t, 2> target_shape = {8, 128};
const std::array<int64_t, 2> mxu_shape = {128, 128};
const int max_sublanes_in_scratch = 0;

MLIRContext *getMLIRContext() { return func.getContext(); }
};
Expand Down

0 comments on commit ed4958c

Please sign in to comment.