Skip to content

Commit

Permalink
Fix OMP fork arg handling (rust-lang#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 5, 2022
1 parent 9e48a51 commit e40a859
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 44 deletions.
34 changes: 1 addition & 33 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3878,7 +3878,7 @@ class AdjointGenerator
}

assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end());
const std::map<Argument *, bool> &uncacheable_argsAbove =
const std::map<Argument *, bool> &uncacheable_args =
uncacheable_args_map.find(&call)->second;

IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
Expand All @@ -3902,38 +3902,6 @@ class AdjointGenerator
"could not derive underlying task contents from omp call");
}

std::map<Argument *, bool> uncacheable_args;
{
auto in_arg = call.getCalledFunction()->arg_begin();
auto pp_arg = task->arg_begin();

// Global.tid is cacheable
uncacheable_args[pp_arg] = false;
++pp_arg;
// Bound.tid is cacheable
uncacheable_args[pp_arg] = false;
++pp_arg;

// Ignore the first three args of init call
++in_arg;
++in_arg;
++in_arg;

for (; pp_arg != task->arg_end();) {
// If var-args then we may still have args even though outermost
// has no more
if (in_arg == call.getCalledFunction()->arg_end()) {
uncacheable_args[pp_arg] = true;
} else {
assert(uncacheable_argsAbove.find(in_arg) !=
uncacheable_argsAbove.end());
uncacheable_args[pp_arg] = uncacheable_argsAbove.find(in_arg)->second;
++in_arg;
}
++pp_arg;
}
}

auto called = task;
// bool modifyPrimal = true;

Expand Down
59 changes: 48 additions & 11 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,14 +537,15 @@ struct CacheAnalysis {
return {};
}

if (Fn->getName().startswith("MPI_") ||
Fn->getName().startswith("enzyme_wrapmpi$$"))
StringRef funcName = Fn->getName();

if (funcName.startswith("MPI_") || funcName.startswith("enzyme_wrapmpi$$"))
return {};

if (Fn->getName() == "__kmpc_for_static_init_4" ||
Fn->getName() == "__kmpc_for_static_init_4u" ||
Fn->getName() == "__kmpc_for_static_init_8" ||
Fn->getName() == "__kmpc_for_static_init_8u") {
if (funcName == "__kmpc_for_static_init_4" ||
funcName == "__kmpc_for_static_init_4u" ||
funcName == "__kmpc_for_static_init_8" ||
funcName == "__kmpc_for_static_init_8u") {
return {};
}

Expand Down Expand Up @@ -644,12 +645,48 @@ struct CacheAnalysis {

std::map<Argument *, bool> uncacheable_args;

auto arg = Fn->arg_begin();
for (unsigned i = 0; i < args.size(); ++i) {
uncacheable_args[arg] = !args_safe[i];
if (funcName == "__kmpc_fork_call") {
Value *op = callsite_op->getArgOperand(2);
Function *task = nullptr;
while (!(task = dyn_cast<Function>(op))) {
if (auto castinst = dyn_cast<ConstantExpr>(op))
if (castinst->isCast()) {
op = castinst->getOperand(0);
continue;
}
if (auto CI = dyn_cast<CastInst>(op)) {
op = CI->getOperand(0);
continue;
}
llvm::errs() << "op: " << *op << "\n";
assert(0 && "unknown fork call arg");
}

auto arg = task->arg_begin();

// Global.tid is cacheable
uncacheable_args[arg] = false;
++arg;
if (arg == Fn->arg_end()) {
break;
// Bound.tid is cacheable
uncacheable_args[arg] = false;
++arg;

// Ignore first three arguments of fork call
for (unsigned i = 3; i < args.size(); ++i) {
uncacheable_args[arg] = !args_safe[i];
++arg;
if (arg == Fn->arg_end()) {
break;
}
}
} else {
auto arg = Fn->arg_begin();
for (unsigned i = 0; i < args.size(); ++i) {
uncacheable_args[arg] = !args_safe[i];
++arg;
if (arg == Fn->arg_end()) {
break;
}
}
}

Expand Down
179 changes: 179 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -adce -simplifycfg -S | FileCheck %s; fi

source_filename = "lulesh.cc"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.ident_t = type { i32, i32, i32, i32, i8* }

@0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 514, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8
@2 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8

; Function Attrs: norecurse nounwind uwtable mustprogress
; Test driver. Hands @_ZL16LagrangeLeapFrogPdm (bitcast to i8*) to the
; __enzyme_autodiff entry point together with four pointers and the constant
; length 100. The pointers are passed as the pairs %out/%dout and %in/%din
; (presumably each primal buffer followed by its derivative buffer -- confirm
; against Enzyme's calling convention).
define void @caller(double* %out, double* %dout, double* %in, double* %din) {
entry:
call void @_Z17__enzyme_autodiffPvS_S_m(i8* bitcast (void (double*, double*, i64)* @_ZL16LagrangeLeapFrogPdm to i8*), double* %out, double* %dout, double* %in, double* %din, i64 100)
ret void
}

declare dso_local void @_Z17__enzyme_autodiffPvS_S_m(i8*, double*, double*, double*, double*, i64)

; Function Attrs: inlinehint nounwind uwtable mustprogress
; Function under differentiation. Its only work is a single call to
; @__kmpc_fork_call whose third operand is @.omp_outlined., bitcast to the
; variadic microtask type void (i32*, i32*, ...)*.
; The trailing variadic operands (%length, %out, %in) correspond positionally
; to the parameters of @.omp_outlined. that follow its two thread-id pointers,
; so %in here is what the outlined body names %tmp.
; NOTE(review): the i32 2 operand is, by the usual OpenMP runtime convention,
; the number of variadic arguments, yet three follow it -- confirm this
; mismatch is intentional for the test.
define internal void @_ZL16LagrangeLeapFrogPdm(double* noalias %out, double* noalias %in, i64 %length) #3 {
entry:
tail call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @2, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64, double*, double*)* @.omp_outlined. to void (i32*, i32*, ...)*), i64 %length, double* %out, double* %in)
ret void
}

; Function Attrs: norecurse nounwind uwtable
; Outlined parallel-region body invoked through @__kmpc_fork_call.
; For every loop index i in the range left in %.omp.lb / %.omp.ub after the
; static-init runtime call, it stores sqrt(%tmp[i]) into %out[i].
;   %.global_tid. - loaded once; the value is passed to both kmpc runtime calls.
;   %.bound_tid.  - never read in this body (marked readnone).
;   %length       - trip count; the body is skipped entirely when it is 0.
;   %out          - destination array, written only.
;   %tmp          - source array, read only.
define internal void @.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture noalias %out, double* nocapture noalias %tmp) #4 {
entry:
; Stack slots whose addresses are handed to the static-init runtime call.
%.omp.lb = alloca i64, align 8
%.omp.ub = alloca i64, align 8
%.omp.stride = alloca i64, align 8
%.omp.is_last = alloca i32, align 4
; %sub4 = length - 1 is the inclusive upper bound of the whole iteration space.
%sub4 = add i64 %length, -1
%cmp.not = icmp eq i64 %length, 0
br i1 %cmp.not, label %omp.precond.end, label %omp.precond.then

omp.precond.then: ; preds = %entry
; Seed lb = 0, ub = length - 1, stride = 1, is_last = 0 before the runtime call.
; The bitcasts %0..%3 are not referenced again in this function.
%0 = bitcast i64* %.omp.lb to i8*
store i64 0, i64* %.omp.lb, align 8, !tbaa !3
%1 = bitcast i64* %.omp.ub to i8*
store i64 %sub4, i64* %.omp.ub, align 8, !tbaa !3
%2 = bitcast i64* %.omp.stride to i8*
store i64 1, i64* %.omp.stride, align 8, !tbaa !3
%3 = bitcast i32* %.omp.is_last to i8*
store i32 0, i32* %.omp.is_last, align 4, !tbaa !7
%4 = load i32, i32* %.global_tid., align 4, !tbaa !7
; Presumably the runtime rewrites lb/ub in place to this thread's chunk --
; the bounds are reloaded from memory right after the call.
call void @__kmpc_for_static_init_8u(%struct.ident_t* nonnull @1, i32 %4, i32 34, i32* nonnull %.omp.is_last, i64* nonnull %.omp.lb, i64* nonnull %.omp.ub, i64* nonnull %.omp.stride, i64 1, i64 1)
; Clamp the reloaded upper bound to at most length - 1 and store it back.
%5 = load i64, i64* %.omp.ub, align 8, !tbaa !3
%cmp6 = icmp ugt i64 %5, %sub4
%cond = select i1 %cmp6, i64 %sub4, i64 %5
store i64 %cond, i64* %.omp.ub, align 8, !tbaa !3
; Enter the loop only if lb < ub + 1 (unsigned), i.e. the chunk is non-empty.
%6 = load i64, i64* %.omp.lb, align 8, !tbaa !3
%add29 = add i64 %cond, 1
%cmp730 = icmp ult i64 %6, %add29
br i1 %cmp730, label %omp.inner.for.body, label %omp.loop.exit

omp.inner.for.body: ; preds = %omp.precond.then, %omp.inner.for.body
; Induction variable starts at lb and advances by 1 each iteration.
%.omp.iv.031 = phi i64 [ %add11, %omp.inner.for.body ], [ %6, %omp.precond.then ]
; out[iv] = sqrt(tmp[iv])
%arrayidx = getelementptr inbounds double, double* %tmp, i64 %.omp.iv.031
%7 = load double, double* %arrayidx, align 8, !tbaa !9
%call = call double @sqrt(double %7) #5
%outidx = getelementptr inbounds double, double* %out, i64 %.omp.iv.031
store double %call, double* %outidx, align 8, !tbaa !9
%add11 = add nuw i64 %.omp.iv.031, 1
; The upper bound is reloaded from memory on every iteration; continue while
; iv + 1 < ub + 1.
%8 = load i64, i64* %.omp.ub, align 8, !tbaa !3
%add = add i64 %8, 1
%cmp7 = icmp ult i64 %add11, %add
br i1 %cmp7, label %omp.inner.for.body, label %omp.loop.exit

omp.loop.exit: ; preds = %omp.inner.for.body, %omp.precond.then
; Paired runtime call that closes the static-init region for this thread id.
call void @__kmpc_for_static_fini(%struct.ident_t* nonnull @1, i32 %4)
br label %omp.precond.end

omp.precond.end: ; preds = %omp.loop.exit, %entry
ret void
}

; Function Attrs: nounwind
declare dso_local void @__kmpc_for_static_init_8u(%struct.ident_t*, i32, i32, i32*, i64*, i64*, i64*, i64, i64) local_unnamed_addr #5

; Function Attrs: nofree nounwind willreturn mustprogress
declare dso_local double @sqrt(double) local_unnamed_addr #6

; Function Attrs: nounwind
declare void @__kmpc_for_static_fini(%struct.ident_t*, i32) local_unnamed_addr #5

; Function Attrs: nounwind
declare !callback !11 void @__kmpc_fork_call(%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) local_unnamed_addr #5

attributes #0 = { norecurse nounwind uwtable }
attributes #1 = { argmemonly }

!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}
!nvvm.annotations = !{}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 7, !"uwtable", i32 1}
!2 = !{!"clang version 13.0.0 (git@github.com:llvm/llvm-project 619bfe8bd23f76b22f0a53fedafbfc8c97a15f12)"}
!3 = !{!4, !4, i64 0}
!4 = !{!"long", !5, i64 0}
!5 = !{!"omnipotent char", !6, i64 0}
!6 = !{!"Simple C++ TBAA"}
!7 = !{!8, !8, i64 0}
!8 = !{!"int", !5, i64 0}
!9 = !{!10, !10, i64 0}
!10 = !{!"double", !5, i64 0}
!11 = !{!12}
!12 = !{i64 2, i64 -1, i64 -1, i1 true}

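; The reverse pass needs the primal input of sqrt. It should not cache that
; value; instead it should reload it from %tmp (see the load through
; %arrayidx_unwrap in the expected output below).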

; CHECK: define internal void @diffe.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* noalias nocapture %out, double* nocapture %"out'", double* noalias nocapture %tmp, double* nocapture %"tmp'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %.omp.lb_smpl = alloca i64
; CHECK-NEXT: %.omp.ub_smpl = alloca i64
; CHECK-NEXT: %.omp.stride_smpl = alloca i64
; CHECK-NEXT: %.omp.is_last = alloca i32
; CHECK-NEXT: %sub4 = add i64 %length, -1
; CHECK-NEXT: %cmp.not = icmp eq i64 %length, 0
; CHECK-NEXT: br i1 %cmp.not, label %invertentry, label %omp.precond.then

; CHECK: omp.precond.then: ; preds = %entry
; CHECK-NEXT: store i32 0, i32* %.omp.is_last
; CHECK-NEXT: %0 = load i32, i32* %.global_tid.
; CHECK-NEXT: store i64 0, i64* %.omp.lb_smpl
; CHECK-NEXT: store i64 %sub4, i64* %.omp.ub_smpl
; CHECK-NEXT: store i64 1, i64* %.omp.stride_smpl
; CHECK-NEXT: call void @__kmpc_for_static_init_8u(%struct.ident_t* nonnull @1, i32 %0, i32 34, i32* nonnull %.omp.is_last, i64* nocapture nonnull %.omp.lb_smpl, i64* nocapture nonnull %.omp.ub_smpl, i64* nocapture nonnull %.omp.stride_smpl, i64 1, i64 1) #0
; CHECK-NEXT: %_unwrap8 = load i64, i64* %.omp.lb_smpl
; CHECK-NEXT: %_unwrap9 = load i64, i64* %.omp.ub_smpl
; CHECK-NEXT: %cmp6_unwrap10 = icmp ugt i64 %_unwrap9, %sub4
; CHECK-NEXT: %cond_unwrap11 = select i1 %cmp6_unwrap10, i64 %sub4, i64 %_unwrap9
; CHECK-NEXT: %add29_unwrap = add i64 %cond_unwrap11, 1
; CHECK-NEXT: %cmp730_unwrap = icmp ult i64 %_unwrap8, %add29_unwrap
; CHECK-NEXT: br i1 %cmp730_unwrap, label %invertomp.loop.exit.loopexit, label %invertomp.precond.then

; CHECK: invertentry: ; preds = %entry, %invertomp.precond.then
; CHECK-NEXT: ret void

; CHECK: invertomp.precond.then: ; preds = %invertomp.inner.for.body, %omp.precond.then
; CHECK-NEXT: %_unwrap = load i32, i32* %.global_tid., align 4, !tbaa !7, !invariant.group !13
; CHECK-NEXT: call void @__kmpc_for_static_fini(%struct.ident_t* @1, i32 %_unwrap)
; CHECK-NEXT: br label %invertentry

; CHECK: invertomp.inner.for.body: ; preds = %invertomp.loop.exit.loopexit, %incinvertomp.inner.for.body
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %_unwrap7, %invertomp.loop.exit.loopexit ], [ %9, %incinvertomp.inner.for.body ]
; CHECK-NEXT: %_unwrap2 = load i64, i64* %.omp.lb_smpl
; CHECK-NEXT: %_unwrap3 = add i64 {{((%_unwrap2, %"iv'ac.0")|%"iv'ac.0", %_unwrap2)}}
; CHECK-NEXT: %"outidx'ipg_unwrap" = getelementptr inbounds double, double* %"out'", i64 %_unwrap3
; CHECK-NEXT: %1 = load double, double* %"outidx'ipg_unwrap", align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8
; CHECK-NEXT: %arrayidx_unwrap = getelementptr inbounds double, double* %tmp, i64 %_unwrap3
; CHECK-NEXT: %_unwrap4 = load double, double* %arrayidx_unwrap, align 8, !tbaa !9, !invariant.group !16
; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %_unwrap4)
; CHECK-NEXT: %3 = fmul fast double 5.000000e-01, %1
; CHECK-NEXT: %4 = fdiv fast double %3, %2
; CHECK-NEXT: %5 = fcmp fast oeq double %_unwrap4, 0.000000e+00
; CHECK-NEXT: %6 = select fast i1 %5, double 0.000000e+00, double %4
; CHECK-NEXT: %"arrayidx'ipg_unwrap" = getelementptr inbounds double, double* %"tmp'", i64 %_unwrap3
; CHECK-NEXT: %7 = atomicrmw fadd double* %"arrayidx'ipg_unwrap", double %6 monotonic
; CHECK-NEXT: %8 = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT: br i1 %8, label %invertomp.precond.then, label %incinvertomp.inner.for.body

; CHECK: incinvertomp.inner.for.body: ; preds = %invertomp.inner.for.body
; CHECK-NEXT: %9 = add nsw i64 %"iv'ac.0", -1
; CHECK-NEXT: br label %invertomp.inner.for.body

; CHECK: invertomp.loop.exit.loopexit: ; preds = %omp.precond.then
; CHECK-NEXT: %_unwrap5 = load i64, i64* %.omp.ub_smpl
; CHECK-NEXT: %cmp6_unwrap = icmp ugt i64 %_unwrap5, %sub4
; CHECK-NEXT: %cond_unwrap = select i1 %cmp6_unwrap, i64 %sub4, i64 %_unwrap5
; CHECK-NEXT: %_unwrap6 = load i64, i64* %.omp.lb_smpl
; CHECK-NEXT: %_unwrap7 = sub i64 %cond_unwrap, %_unwrap6
; CHECK-NEXT: br label %invertomp.inner.for.body
; CHECK-NEXT: }

0 comments on commit e40a859

Please sign in to comment.