Skip to content

Commit

Permalink
Selective cachereads (#21)
Browse files Browse the repository at this point in the history
* Enable cache reads by default, which is needed for correctness. 
* Selectively omit caching for reads whose value does not change after the load instruction. Loads whose source memory may be modified after the load instruction are called "uncacheable" in the code.
* Propagate the uncacheable status of pointer arguments to calls.
* The readwriteread C test illustrates the behavior of the code.
  • Loading branch information
timkaler authored and wsmoses committed May 17, 2021
1 parent 189a8ff commit 5bf2105
Show file tree
Hide file tree
Showing 21 changed files with 639 additions and 169 deletions.
3 changes: 2 additions & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo

bool differentialReturn = cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();

auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr);//, LI, DT);
std::set<unsigned> volatile_args;
auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr, volatile_args);//, LI, DT);

if (differentialReturn)
args.push_back(ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0));
Expand Down
447 changes: 411 additions & 36 deletions enzyme/Enzyme/EnzymeLogic.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ extern llvm::cl::opt<bool> enzyme_print;
//! return structtype if recursive function
std::pair<llvm::Function*,llvm::StructType*> CreateAugmentedPrimal(llvm::Function* todiff, llvm::AAResults &AA, const std::set<unsigned>& constant_args, llvm::TargetLibraryInfo &TLI, bool differentialReturn);

llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set<unsigned>& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg);
llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set<unsigned>& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set<unsigned> volatile_args);

#endif
32 changes: 23 additions & 9 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,13 @@ PHINode* canonicalizeIVs(fake::SCEVExpander &e, Type *Ty, Loop *L, DominatorTree

Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) {
static std::map<Function*,Function*> cache;
if (cache.find(F) != cache.end()) return cache[F];

static std::map<Function*, BasicAAResult*> cache_AA;
llvm::errs() << "Before cache lookup for " << F->getName() << "\n";
if (cache.find(F) != cache.end()) {
AA.addAAResult(*(cache_AA[F]));
return cache[F];
}
llvm::errs() << "Did not do cache lookup for " << F->getName() << "\n";
Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), "preprocess_" + F->getName(), F->getParent());

ValueToValueMapTy VMap;
Expand Down Expand Up @@ -439,7 +444,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
FunctionAnalysisManager AM;
AM.registerPass([] { return AAManager(); });
AM.registerPass([] { return ScalarEvolutionAnalysis(); });
AM.registerPass([] { return AssumptionAnalysis(); });
//AM.registerPass([] { return AssumptionAnalysis(); });
AM.registerPass([] { return TargetLibraryAnalysis(); });
AM.registerPass([] { return TargetIRAnalysis(); });
AM.registerPass([] { return LoopAnalysis(); });
Expand All @@ -458,13 +463,22 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(AM); });

//Alias analysis is necessary to ensure can query whether we can move a forward pass function
BasicAA ba;
auto baa = new BasicAAResult(ba.run(*NewF, AM));
//BasicAA ba;
//auto baa = new BasicAAResult(ba.run(*NewF, AM));
AssumptionCache* AC = new AssumptionCache(*NewF);
TargetLibraryInfo* TLI = new TargetLibraryInfo(AM.getResult<TargetLibraryAnalysis>(*NewF));
auto baa = new BasicAAResult(NewF->getParent()->getDataLayout(),
*NewF,
*TLI,
*AC,
&AM.getResult<DominatorTreeAnalysis>(*NewF),
AM.getCachedResult<LoopAnalysis>(*NewF),
AM.getCachedResult<PhiValuesAnalysis>(*NewF));
cache_AA[F] = baa;
AA.addAAResult(*baa);

ScopedNoAliasAA sa;
auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM));
AA.addAAResult(*saa);
//ScopedNoAliasAA sa;
//auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM));
//AA.addAAResult(*saa);

}

Expand Down
13 changes: 8 additions & 5 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
return invertedPointers[val] = cs;
} else if (auto fn = dyn_cast<Function>(val)) {
//! Todo allow tape propagation
auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr);
std::set<unsigned> uncacheable_args;
auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, uncacheable_args);
return BuilderM.CreatePointerCast(newf, fn->getType());
} else if (auto arg = dyn_cast<CastInst>(val)) {
auto result = BuilderM.CreateCast(arg->getOpcode(), invertPointerM(arg->getOperand(0), BuilderM), arg->getDestTy(), arg->getName()+"'ipc");
Expand Down Expand Up @@ -824,10 +825,12 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) {
}
}

if (!shouldRecompute(inst, available)) {
auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true);
assert(op);
return op;
if (!(*(this->can_modref_map))[inst]) {
if (!shouldRecompute(inst, available)) {
auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true);
assert(op);
return op;
}
}
/*
if (!inLoop) {
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class GradientUtils {
ValueToValueMapTy scopeFrees;
ValueToValueMapTy originalToNewFn;

std::map<Instruction*, bool>* can_modref_map;


Value* getNewFromOriginal(Value* originst) {
assert(originst);
auto f = originalToNewFn.find(originst);
Expand Down Expand Up @@ -507,7 +510,7 @@ class GradientUtils {
}
assert(lastScopeAlloc.find(malloc) == lastScopeAlloc.end());
cast<Instruction>(malloc)->replaceAllUsesWith(ret);
auto n = malloc->getName();
std::string n = malloc->getName().str();
erase(cast<Instruction>(malloc));
ret->setName(n);
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/functional_tests_c/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ OBJ := $(wildcard *.c)

all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ))

POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true
POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg

#all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ))
#clean:
Expand All @@ -31,7 +31,7 @@ POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true

#EXTRA_FLAGS = -indvars -loop-simplify -loop-rotate

# NOTE(TFK): Optimization level 0 is broken right now.
# /efs/home/tfk/valgrind-3.12.0/vg-in-place
build/%-enzyme0: %.c
@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
@./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc
Expand Down
39 changes: 6 additions & 33 deletions enzyme/functional_tests_c/insertsort_sum.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ float* unsorted_array_init(int N) {
return arr;
}

// sums the first half of a sorted array.
void insertsort_sum (float* array, int N, float* ret) {
void insertsort_sum (float*__restrict array, int N, float*__restrict ret) {
float sum = 0;
//qsort(array, N, sizeof(float), cmp);

for (int i = 1; i < N; i++) {
int j = i;
Expand All @@ -31,30 +29,16 @@ void insertsort_sum (float* array, int N, float* ret) {
}
}


for (int i = 0; i < N/2; i++) {
printf("Val: %f\n", array[i]);
//printf("Val: %f\n", array[i]);
sum += array[i];
}

*ret = sum;
}




int main(int argc, char** argv) {



float a = 2.0;
float b = 3.0;



float da = 0;
float db = 0;


float ret = 0;
float dret = 1.0;

Expand All @@ -71,18 +55,15 @@ int main(int argc, char** argv) {
printf("%d:%f\n", i, array[i]);
}

//insertsort_sum(array, N, &ret);
__builtin_autodiff(insertsort_sum, array, d_array, N, &ret, &dret);

printf("The total sum is %f\n", ret);

printf("Array after sorting:\n");
for (int i = 0; i < N; i++) {
printf("%d:%f\n", i, array[i]);
}


printf("The total sum is %f\n", ret);

__builtin_autodiff(insertsort_sum, array, d_array, N, &ret, &dret);

for (int i = 0; i < N; i++) {
printf("Diffe for index %d is %f\n", i, d_array[i]);
if (i%2 == 0) {
Expand All @@ -91,13 +72,5 @@ int main(int argc, char** argv) {
assert(d_array[i] == 1.0);
}
}

//__builtin_autodiff(compute_loops, &a, &da, &b, &db, &ret, &dret);


//assert(da == 100*1.0f);
//assert(db == 100*1.0f);

//printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
return 0;
}
4 changes: 2 additions & 2 deletions enzyme/functional_tests_c/insertsort_sum_alt.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void insertion_sort_inner(float* array, int i) {
}

// sums the first half of a sorted array.
void insertsort_sum (float* array, int N, float* ret) {
void insertsort_sum (float*__restrict array, int N, float*__restrict ret) {
float sum = 0;
//qsort(array, N, sizeof(float), cmp);

Expand All @@ -45,7 +45,7 @@ void insertsort_sum (float* array, int N, float* ret) {


for (int i = 0; i < N/2; i++) {
printf("Val: %f\n", array[i]);
//printf("Val: %f\n", array[i]);
sum += array[i];
}
*ret = sum;
Expand Down
46 changes: 46 additions & 0 deletions enzyme/functional_tests_c/readwriteread.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>
#define __builtin_autodiff __enzyme_autodiff
double __enzyme_autodiff(void*, ...);

// Returns the square of the value stored at x.
// The pointee is read through two separate dereferences and is never written.
double f_read(double* x) {
    return (*x) * (*x);
}

// Scales the value stored at x in place: afterwards *x holds the old *x
// multiplied by product.
void g_write(double* x, double product) {
    *x *= product;
}

// Returns the value currently stored at x without modifying it.
double h_read(double* x) {
    const double current = *x;
    return current;
}

// Read-write-read sequence on *x: reads it via f_read, overwrites it via
// g_write using the value f_read returned, then reads it again via h_read
// and returns that final read.
double readwriteread_helper(double* x) {
    const double squared = f_read(x);
    g_write(x, squared);
    return h_read(x);
}

// Function handed to __builtin_autodiff in main: stores the value returned by
// readwriteread_helper(x) into *ret. Both pointers are declared __restrict.
void readwriteread(double*__restrict x, double*__restrict ret) {
    const double result = readwriteread_helper(x);
    *ret = result;
}

// Test driver: runs __builtin_autodiff on readwriteread with *x = 2.0 and
// checks the value the call leaves in *dx.
// Returns 0 on success and 1 if the input buffers cannot be allocated.
int main(int argc, char** argv) {
    // The command line is not used by this test.
    (void)argc;
    (void)argv;

    double ret = 0;
    double dret = 1.0;  // presumably the derivative seed for ret; confirm against Enzyme docs
    double* x = (double*) malloc(sizeof(double));
    double* dx = (double*) malloc(sizeof(double));
    if (x == NULL || dx == NULL) {
        // Fail cleanly instead of dereferencing a null pointer below.
        fprintf(stderr, "readwriteread: out of memory\n");
        free(x);
        free(dx);
        return 1;
    }
    *x = 2.0;
    *dx = 0.0;

    __builtin_autodiff(readwriteread, x, dx, &ret, &dret);

    printf("dx is %f ret is %f\n", *dx, ret);
    // readwriteread computes x * x * x (square, then multiply by x), so the
    // expected derivative at x = 2.0 is 3 * 2.0 * 2.0.
    assert(*dx == 3*2.0*2.0);

    // Release the buffers allocated above.
    free(x);
    free(dx);
    return 0;
}
4 changes: 2 additions & 2 deletions enzyme/functional_tests_c/setup.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash

# NOTE(TFK): Uncomment for local testing.
export CLANG_BIN_PATH=./../../build-dbg/bin
export ENZYME_PLUGIN=./../mkdebug/Enzyme/LLVMEnzyme-7.so
export CLANG_BIN_PATH=./../../llvm/build/bin/
export ENZYME_PLUGIN=./../build/Enzyme/LLVMEnzyme-7.so

mkdir -p build
$@
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
; RUN: cd %desired_wd
; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme
; RUN: make build/readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
; RUN: build/readwriteread-enzyme0
; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
; RUN: cd %desired_wd
; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme
; RUN: make build/readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
; RUN: build/readwriteread-enzyme1
; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
; RUN: cd %desired_wd
; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme
; RUN: make build/readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
; RUN: build/readwriteread-enzyme2
; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
; RUN: cd %desired_wd
; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme
; RUN: make build/readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
; RUN: build/readwriteread-enzyme3
; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme

30 changes: 18 additions & 12 deletions enzyme/test/Enzyme/badcall.ll
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ attributes #1 = { noinline nounwind uwtable }

; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { { {} } } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: %0 = call { { {}, double } } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: %1 = extractvalue { { {}, double } } %0, 0
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, double } %1)
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal {{(dso_local )?}}{ {} } @augmented_metasubf(double* nocapture %x, double* %"x'")
Expand All @@ -56,16 +57,21 @@ attributes #1 = { noinline nounwind uwtable }
; CHECK-NEXT: ret { {} } undef
; CHECK-NEXT: }

; CHECK: define internal {{(dso_local )?}}{ { {} } } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK: define internal {{(dso_local )?}}{ { {}, double } } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = load double, double* %x, align 8
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
; CHECK-NEXT: store double %mul, double* %x, align 8
; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: ret { { {} } } undef
; CHECK-NEXT: %0 = alloca { { {}, double } }
; CHECK-NEXT: %1 = getelementptr { { {}, double } }, { { {}, double } }* %0, i32 0, i32 0
; CHECK-NEXT: %2 = load double, double* %x, align 8
; CHECK-NEXT: %3 = getelementptr { {}, double }, { {}, double }* %1, i32 0, i32 1
; CHECK-NEXT: store double %2, double* %3
; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
; CHECK-NEXT: store double %mul, double* %x, align 8
; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: %5 = load { { {}, double } }, { { {}, double } }* %0
; CHECK-NEXT: ret { { {}, double } } %5
; CHECK-NEXT: }

; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, double } %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %1 = load double, double* %"x'"
Expand Down
Loading

0 comments on commit 5bf2105

Please sign in to comment.