Skip to content

Commit

Permalink
Use fibers to guarantee stack size on Windows (#5873)
Browse files Browse the repository at this point in the history
* Use fibers for lowering.

* Move fibers to Util

* Wrap compile_func call in call_with_stack_requirement

* Rename call_with_stack_requirement -> run_with_large_stack

* Appease clang_format

* Add exception handling to run_with_large_stack

* clang-format

* Fix 32-bit?

* Fix error wording for Makefile

* Improve naming in run_with_large_stack
  • Loading branch information
alexreinking committed Apr 2, 2021
1 parent 42092e3 commit 59a04e4
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cmake/HalideTestHelpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ function(tests)
endforeach ()

set(TEST_NAMES "${TEST_NAMES}" PARENT_SCOPE)
endfunction(tests)
endfunction()
4 changes: 3 additions & 1 deletion src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
for (const auto &f : input.functions()) {
const auto names = get_mangled_names(f, get_target());

compile_func(f, names.simple_name, names.extern_name);
run_with_large_stack([&]() {
compile_func(f, names.simple_name, names.extern_name);
});

// If the Func is externally visible, also create the argv wrapper and metadata.
// (useful for calling from JIT and other machine interfaces).
Expand Down
39 changes: 24 additions & 15 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,17 @@ class LoweringLogger {
}
};

} // namespace

Module lower(const vector<Function> &output_funcs,
const string &pipeline_name,
const Target &t,
const vector<Argument> &args,
const LinkageType linkage_type,
const vector<Stmt> &requirements,
bool trace_pipeline,
const vector<IRMutator *> &custom_passes) {
void lower_impl(const vector<Function> &output_funcs,
const string &pipeline_name,
const Target &t,
const vector<Argument> &args,
const LinkageType linkage_type,
const vector<Stmt> &requirements,
bool trace_pipeline,
const vector<IRMutator *> &custom_passes,
Module &result_module) {
auto time_start = std::chrono::high_resolution_clock::now();

std::vector<std::string> namespaces;
std::string simple_pipeline_name = extract_namespaces(pipeline_name, namespaces);

Module result_module(simple_pipeline_name, t);

// Compute an environment
map<string, Function> env;
for (const Function &f : output_funcs) {
Expand Down Expand Up @@ -524,7 +518,22 @@ Module lower(const vector<Function> &output_funcs,
std::chrono::duration<double> diff = time_end - time_start;
logger->record_compilation_time(CompilerLogger::Phase::HalideLowering, diff.count());
}
}

} // namespace

/** Public lowering entry point.
 *
 * Creates the result Module (named with the pipeline name stripped of its
 * namespaces) and then delegates all the actual work to lower_impl, which is
 * invoked through run_with_large_stack. The populated Module is returned. */
Module lower(const vector<Function> &output_funcs,
             const string &pipeline_name,
             const Target &t,
             const vector<Argument> &args,
             const LinkageType linkage_type,
             const vector<Stmt> &requirements,
             bool trace_pipeline,
             const vector<IRMutator *> &custom_passes) {
    Module lowered(extract_namespaces(pipeline_name), t);

    // Everything is captured by reference: the callable is run to completion
    // before run_with_large_stack returns, so nothing outlives this frame.
    const auto do_lowering = [&]() {
        lower_impl(output_funcs,
                   pipeline_name,
                   t,
                   args,
                   linkage_type,
                   requirements,
                   trace_pipeline,
                   custom_passes,
                   lowered);
    };
    run_with_large_stack(do_lowering);

    return lowered;
}

Expand Down
83 changes: 83 additions & 0 deletions src/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ std::string extract_namespaces(const std::string &name, std::vector<std::string>
return result;
}

// Convenience overload for callers that only want the base name: forwards to
// the two-argument version and throws away the namespace list it fills in.
std::string extract_namespaces(const std::string &name) {
    std::vector<std::string> discarded_namespaces;
    std::string base_name = extract_namespaces(name, discarded_namespaces);
    return base_name;
}

bool file_exists(const std::string &name) {
#ifdef _MSC_VER
return _access(name.c_str(), 0) == 0;
Expand Down Expand Up @@ -572,6 +577,84 @@ int get_llvm_version() {
return LLVM_VERSION;
}

#ifdef _WIN32

namespace {

// Bundle of state handed to generic_fiber_entry_point through the fiber's
// LPVOID parameter. Field order matters: run_with_large_stack brace-initializes
// this struct positionally as {action, main_fiber}.
struct GenericFiberArgs {
// The callable to invoke on the fiber. Held by reference, so the referenced
// std::function must stay alive for as long as the fiber may call it.
const std::function<void()> &run;
// The fiber that generic_fiber_entry_point switches back to after `run` finishes.
LPVOID main_fiber;
#ifdef HALIDE_WITH_EXCEPTIONS
// Filled in by generic_fiber_entry_point with any exception thrown by `run`;
// remains null when `run` completes normally.
std::exception_ptr exception = nullptr; // NOLINT - clang-tidy complains this isn't thrown
#endif
};

// Fiber start routine. `argument` is the GenericFiberArgs* that was passed to
// CreateFiber: run its callable, stash any exception it throws (when Halide is
// built with exceptions), and then hand control back to the recorded fiber.
void WINAPI generic_fiber_entry_point(LPVOID argument) {
    auto *const fiber_args = static_cast<GenericFiberArgs *>(argument);

#ifdef HALIDE_WITH_EXCEPTIONS
    // An exception must not propagate out of this function; capture it so the
    // code that resumes on the main fiber can inspect it.
    try {
        fiber_args->run();
    } catch (...) {
        fiber_args->exception = std::current_exception();
    }
#else
    fiber_args->run();
#endif

    SwitchToFiber(fiber_args->main_fiber);
}

} // namespace

#endif

// Invoke `action`, making sure it has a large stack to run on.
//
// On Windows, if fewer than 8MB remain between the current stack position and
// the low limit of the thread's stack, `action` is run on a freshly created
// fiber with an 8MB stack and control returns here when it finishes. In every
// other case (enough stack left, or not Windows) `action` is called directly.
//
// When built with HALIDE_WITH_EXCEPTIONS, an exception thrown by `action` on
// the fiber is captured there and rethrown from this function.
void run_with_large_stack(const std::function<void()> &action) {
#ifdef _WIN32
    constexpr SIZE_T required_stack = 8 * 1024 * 1024;

    // Only exists for its address, which is used to compute remaining stack space.
    // Deliberately left uninitialized; its value is never read.
    ULONG_PTR approx_stack_pos;

    ULONG_PTR stack_low, stack_high;
    GetCurrentThreadStackLimits(&stack_low, &stack_high);
    ptrdiff_t stack_remaining = (char *)&approx_stack_pos - (char *)stack_low;

    if (stack_remaining < required_stack) {
        debug(1) << "Insufficient stack space (" << stack_remaining << " bytes). Switching to fiber with " << required_stack << "-byte stack.\n";

        // Fibers can only be switched to from another fiber, so the calling
        // thread is converted unless it already is one. Remember which case
        // applies so the conversion can be undone afterwards.
        auto was_a_fiber = IsThreadAFiber();

        auto *main_fiber = was_a_fiber ? GetCurrentFiber() : ConvertThreadToFiber(nullptr);
        internal_assert(main_fiber) << "ConvertThreadToFiber failed with code: " << GetLastError() << "\n";

        GenericFiberArgs fiber_args{action, main_fiber};
        auto *lower_fiber = CreateFiber(required_stack, generic_fiber_entry_point, &fiber_args);
        internal_assert(lower_fiber) << "CreateFiber failed with code: " << GetLastError() << "\n";

        // Runs `action` on the new fiber; the entry point switches back here
        // once it has finished (or captured an exception).
        SwitchToFiber(lower_fiber);
        DeleteFiber(lower_fiber);

        debug(1) << "Returned from fiber.\n";

        // Restore the thread to its original (non-fiber) state BEFORE any
        // rethrow below. Doing this after the rethrow would skip it on the
        // exception path and leave the thread permanently converted to a fiber.
        if (!was_a_fiber) {
            BOOL success = ConvertFiberToThread();
            internal_assert(success) << "ConvertFiberToThread failed with code: " << GetLastError() << "\n";
        }

#ifdef HALIDE_WITH_EXCEPTIONS
        if (fiber_args.exception) {
            debug(1) << "Fiber threw exception. Rethrowing...\n";
            std::rethrow_exception(fiber_args.exception);
        }
#endif

        return;
    }

#endif

    action();
}

} // namespace Internal

void load_plugin(const std::string &lib_name) {
Expand Down
14 changes: 9 additions & 5 deletions src/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
#include <utility>
Expand Down Expand Up @@ -44,11 +45,6 @@
#define HALIDE_NO_USER_CODE_INLINE HALIDE_NEVER_INLINE
#endif

// On windows, Halide needs a larger stack than the default MSVC provides
#ifdef _MSC_VER
#pragma comment(linker, "/STACK:8388608,1048576")
#endif

namespace Halide {

/** Load a plugin in the form of a dynamic library (e.g. for custom autoschedulers).
Expand Down Expand Up @@ -212,6 +208,9 @@ struct all_are_convertible : meta_and<std::is_convertible<Args, To>...> {};
/** Returns base name and fills in namespaces, outermost one first in vector. */
std::string extract_namespaces(const std::string &name, std::vector<std::string> &namespaces);

/** Overload that returns base name only */
std::string extract_namespaces(const std::string &name);

struct FileStat {
uint64_t file_size;
uint32_t mod_time; // Unix epoch time
Expand Down Expand Up @@ -466,6 +465,11 @@ std::string c_print_name(const std::string &name);
* of Halide tests. */
int get_llvm_version();

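/** Call the given action in a platform-specific context that provides at least
* 8MB of stack space. Currently only has any effect on Windows, where the action
* is run on a Fiber with an 8MB stack if the calling thread has less than 8MB of
* stack remaining. Otherwise the action is simply called directly. */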
void run_with_large_stack(const std::function<void()> &action);

} // namespace Internal
} // namespace Halide

Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ tests(GROUPS error
reuse_var_in_schedule.cpp
reused_args.cpp
rfactor_inner_dim_non_commutative.cpp
run_with_large_stack_throws.cpp
specialize_fail.cpp
split_inner_wrong_tail_strategy.cpp
thread_id_outside_block_id.cpp
Expand Down
16 changes: 16 additions & 0 deletions test/error/run_with_large_stack_throws.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "Halide.h"
#include <iostream>

int main() {
try {
Halide::Internal::run_with_large_stack([]() {
throw Halide::RuntimeError("Error from run_with_large_stack");
});
} catch (const Halide::RuntimeError &ex) {
std::cerr << ex.what() << "\n";
return 1;
}

std::cout << "Success!\n";
return 0;
}
2 changes: 1 addition & 1 deletion tutorial/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ add_dependencies(lesson_15_targets

##
add_test(NAME tutorial_lesson_15_build_gens
COMMAND ${CMAKE_COMMAND} --build ${CMAKE_BINARY_DIR} --target lesson_15_targets)
COMMAND ${CMAKE_COMMAND} --build ${CMAKE_BINARY_DIR} --target lesson_15_targets --config $<CONFIG>)
set_tests_properties(tutorial_lesson_15_build_gens PROPERTIES
LABELS tutorial
FIXTURES_SETUP tutorial_lesson_15)
Expand Down

0 comments on commit 59a04e4

Please sign in to comment.