Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "[Dy2Stat] Refactor ExecutorCache logic and pre-support BuildStrategy for pass (#34181)" (#34348)" #34384

Merged
merged 1 commit into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions paddle/fluid/framework/executor_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/op_info.h"

namespace paddle {
namespace framework {
Expand All @@ -25,11 +26,11 @@ namespace framework {

namespace details {

static ExecutionStrategy GetExecutionStrategy(
const ExecutorInfoCache::CacheKey &cache_key) {
static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
framework::ExecutionStrategy execution_strategy;

switch (cache_key.device_type_) {
auto device_type = platform::Place2DeviceType(place);
switch (device_type) {
case platform::DeviceType::CPU: {
execution_strategy.num_threads_ = 2;
break;
Expand All @@ -46,9 +47,9 @@ static ExecutionStrategy GetExecutionStrategy(
}
default:
PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
cache_key.device_type_));
device_type));
}
execution_strategy.use_device_ = cache_key.device_type_;
execution_strategy.use_device_ = device_type;

return execution_strategy;
}
Expand Down Expand Up @@ -136,58 +137,51 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
return g_exe_cache_info_map;
}

void ExecutorInfoCache::Finalize() {
// NOTE(Aurelius84): DO NOT perform finalize in destructor
// to avoid problems caused by destructor order of static
// object.
info_map_.clear();
}

CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey &cache_key,
CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index, int64_t end_op_index,
bool is_grad, int64_t program_id,
framework::Scope *scope) {
auto &cached_exe_info = framework::ExecutorInfoCache::Instance();

if (!cached_exe_info.Has(cache_key)) {
VLOG(1) << "create exe_info for " << cache_key.DebugString();

if (!cached_exe_info.Has(program_id, is_grad)) {
// TODO(Aurelius84): Consider using an LRU algorithm to replace this.
if (cached_exe_info.Size() > 4u /* max_cached_size*/) {
VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
"all cache!";
cached_exe_info.Finalize();
}

framework::BuildStrategy build_strategy;
auto execution_strategy = details::GetExecutionStrategy(cache_key);
VLOG(1) << "create exe_info for " << program_id << " is_grad: " << is_grad;
auto execution_strategy = details::GetExecutionStrategy(place);
auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);

// 2. Construct Graph and ParallelExecutor.
auto graph = std::make_shared<framework::ir::Graph>(
*cache_key.program_desc_, cache_key.start_op_index_,
cache_key.end_op_index_);
program_desc, start_op_index, end_op_index);
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
cache_key.place_, scope, execution_strategy, build_strategy,
graph.get());
place, scope, execution_strategy, build_strategy, graph.get());
parallel_executor->PrepareVariables(scope);

framework::ExecutorInfoCache::ValueType cache_val = {parallel_executor,
graph};
cached_exe_info.Insert(cache_key, cache_val);

bool is_new_created = true;
return std::make_pair(parallel_executor, is_new_created);
// 3. Insert value into cached map.
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
cached_value.executor_ = parallel_executor;
cached_value.graph_ = std::move(graph);
return std::make_pair(parallel_executor, /*is_new_created=*/true);
} else {
VLOG(1) << "get exe_info from cache by: " << cache_key.DebugString();
bool is_new_created = false;
auto cache_val = cached_exe_info.GetMutable(cache_key);
auto parallel_executor = cache_val.first;
VLOG(1) << "get exe_info from cache by: " << program_id
<< " is_grad: " << is_grad;
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);

auto &parallel_executor = cached_value.executor_;
// update op_handle scope_map in pe->executor_->Graph
std::unordered_map<Scope *, Scope *> scope_map = {
{parallel_executor->GetLocalScopes().front(), scope}};
parallel_executor->ResetOpHandleScopeMapOfGraphs(scope_map);
// need to recreate tmp variables in new scope
parallel_executor->PrepareVariables(scope);

return std::make_pair(parallel_executor, is_new_created);
return std::make_pair(parallel_executor, /*is_new_created=*/false);
}
}

Expand Down
153 changes: 62 additions & 91 deletions paddle/fluid/framework/executor_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,121 +45,92 @@ void ParseSafeEagerDeletionSkipVars(
std::vector<std::string>* skip_eager_delete_vars);

} // namespace details
class ExecutorInfoCache {

class ExecutorInfo {
public:
struct CacheKey {
CacheKey(const ProgramDesc* program_desc, const platform::Place& place,
int64_t start_op_index, int64_t end_op_index, bool is_grad)
: program_desc_(program_desc),
place_(place),
start_op_index_(start_op_index),
end_op_index_(end_op_index),
is_grad_(is_grad) {
device_type_ = platform::Place2DeviceType(place);
PADDLE_ENFORCE_NOT_NULL(program_desc_,
"program_desc should not be null.");
}

std::string DebugString() const {
std::stringstream ss;

ss << "\n CacheKey(program_desc: " << program_desc_;
ss << ", start_op_index: " << start_op_index_;
ss << ", end_op_index: " << end_op_index_;
ss << ", is_grad: " << is_grad_;
ss << ", device_type: " << device_type_ << ")";

return ss.str();
}

const ProgramDesc* program_desc_;
platform::Place place_;
int64_t start_op_index_;
int64_t end_op_index_;
bool is_grad_;
platform::DeviceType device_type_;
};
struct CacheValue {
std::shared_ptr<ParallelExecutor> executor_{nullptr};
std::shared_ptr<ir::Graph> graph_{nullptr};

using KeyType = size_t;
using ValueType =
std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;

struct KeyHasher {
size_t operator()(const CacheKey& key) const noexcept {
size_t seed = 10;
auto* prog_desc = key.program_desc_;
/*
* Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
* because a new program will hold same pointer address after an older
* program is destructed with a small probability. Add op size while
* hashing because program may contains at least one block.
*/
hash_combine(&seed, prog_desc);
for (size_t i = 0; i < prog_desc->Size(); ++i) {
hash_combine(&seed, &prog_desc->Block(i));
hash_combine(&seed, prog_desc->Block(i).OpSize());
}
hash_combine(&seed, static_cast<int>(key.device_type_));
hash_combine(&seed, key.start_op_index_);
hash_combine(&seed, key.end_op_index_);
hash_combine(&seed, key.is_grad_);
VLOG(3) << "hash value is : " << seed
<< " of key: " << key.DebugString();
return seed;
}

template <typename T>
void hash_combine(size_t* seed, const T& val) const {
std::hash<T> hasher;
(*seed) ^= hasher(val) + 0x9e3779b9 + ((*seed) << 6) + ((*seed >> 2));
}
std::vector<std::string> skip_eager_delete_vars_;
};

// Reports whether an executor has already been stored for the requested
// direction (backward when is_grad is true, forward otherwise).
bool IsAvailable(bool is_grad) {
  if (is_grad) {
    return backward_info_.executor_ != nullptr;
  }
  return forward_info_.executor_ != nullptr;
}

// Returns a mutable reference to the backward slot when is_grad is true,
// and to the forward slot otherwise.
CacheValue& GetMutable(bool is_grad) {
  if (is_grad) {
    return backward_info_;
  }
  return forward_info_;
}

private:
CacheValue forward_info_;
CacheValue backward_info_;
};

class ExecutorInfoCache {
public:
static ExecutorInfoCache& Instance();

ValueType GetMutable(const CacheKey& key) {
auto key_val = key_hash_func_(key);
// Looks up the BuildStrategy stored for program_id. operator[] is used on
// purpose: a program_id that is not present yet gets a default-constructed
// BuildStrategy inserted and returned.
const BuildStrategy& GetBuildStrategy(int64_t program_id) {
  const BuildStrategy& strategy = strategy_map_[program_id];
  return strategy;
}

void SetBuildStrategy(int64_t program_id,
const BuildStrategy& build_strategy) {
PADDLE_ENFORCE_EQ(
Has(key_val), true,
platform::errors::NotFound("%s doesn't exist in ExecutorInfoCache",
key.DebugString()));
return info_map_[key_val];
strategy_map_.count(program_id), 0,
platform::errors::PreconditionNotMet(
"program_id: %s already exist in ExecutorInfoCache", program_id));
strategy_map_[program_id] = build_strategy;
}

bool Has(const CacheKey& key) const {
auto key_val = key_hash_func_(key);
return Has(key_val);
// Returns true only when program_id has an entry in info_map_ and that entry
// already holds an executor for the requested direction (is_grad selects the
// backward slot, otherwise the forward slot is checked).
bool Has(int64_t program_id, bool is_grad) {
  // Reuse the iterator from find() instead of hashing program_id a second
  // time through operator[].
  auto iter = info_map_.find(program_id);
  return iter != info_map_.end() && iter->second.IsAvailable(is_grad);
}

bool Has(const KeyType& key) const {
return info_map_.find(key) != info_map_.end();
// Returns the mutable cache slot of program_id for the given direction.
// info_map_ is indexed with operator[], so an unknown program_id gets a
// default-constructed ExecutorInfo inserted first.
ExecutorInfo::CacheValue& GetMutable(int64_t program_id, bool is_grad) {
  ExecutorInfo& info = info_map_[program_id];
  return info.GetMutable(is_grad);
}

void Insert(const CacheKey& key, ValueType value) {
auto key_val = key_hash_func_(key);
PADDLE_ENFORCE_EQ(
Has(key_val), false,
platform::errors::NotFound("%s has existed in ExecutorInfoCache",
key.DebugString()));
info_map_.insert({key_val, value});
// Replaces the skip-eager-delete variable names cached for
// (program_id, is_grad) with a copy of skip_vars.
void UpdateSkipEagerDeleteVars(int64_t program_id, bool is_grad,
                               const std::vector<std::string>& skip_vars) {
  auto& cached_value = GetMutable(program_id, is_grad);
  // skip_vars is a const reference, so std::move could only ever copy; spell
  // the copy out so the code does not suggest the caller's vector is drained.
  cached_value.skip_eager_delete_vars_ = skip_vars;
}

// Returns a mutable reference to the skip-eager-delete variable names stored
// for (program_id, is_grad).
std::vector<std::string>& SkipEagerDeleteVars(int64_t program_id,
                                              bool is_grad) {
  return GetMutable(program_id, is_grad).skip_eager_delete_vars_;
}

// Number of program ids that currently have an entry in info_map_.
size_t Size() const { return info_map_.size(); }

void Finalize();
// Empties the cache: removes every entry from info_map_ and strategy_map_.
void Finalize() {
// NOTE(Aurelius84): DO NOT perform finalize in the destructor,
// to avoid problems caused by the destruction order of static
// objects.
info_map_.clear();
strategy_map_.clear();
}

private:
ExecutorInfoCache() = default;
DISABLE_COPY_AND_ASSIGN(ExecutorInfoCache);

KeyHasher key_hash_func_;
std::unordered_map<KeyType, ValueType> info_map_;
std::unordered_map<int64_t, ExecutorInfo> info_map_;
std::unordered_map<int64_t, BuildStrategy> strategy_map_;
};

// Result type of GetExecutorInfoFromCache: the executor, paired with a flag
// that is true when the executor was created by that call and false when it
// was taken from the cache.
using CacheInfo =
std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;

CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey& cache_key,
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
const platform::Place& place,
int64_t start_op_index, int64_t end_op_index,
bool is_grad, int64_t program_id,
framework::Scope* scope);

} // namespace framework
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/run_program_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training.")
.SetDefault(false);
AddAttr<int64_t>(
"program_id",
"(int64_t)"
"The unique hash id used as cache key for ExecutorInfoCache.");
AddComment(R"DOC(
RunProgram operator.

Expand Down
Loading