Merge branch 'op-device-placer' into 'master'
Feature: Add registration of Op conditions.

See merge request !942
李寅 committed Jan 3, 2019
2 parents 8a468a9 + 4a298aa commit aee1f76
Showing 31 changed files with 379 additions and 170 deletions.
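For context, here is a minimal sketch (not part of this diff) of how an op could opt into the new condition registration. It relies only on the OpConditionBuilder, SetDevicePlacerFunc, and OpRegistryBase::Register(const OpConditionBuilder &) APIs introduced in mace/core/operator.cc below; the "MyOp" name, the helper function, and the 4-D output check are illustrative assumptions.

#include <set>

#include "mace/core/operator.h"

namespace mace {

// Hypothetical registration helper: restrict "MyOp" to CPU unless its
// output shapes are known and 4-D, mirroring the default device_placer
// added in OpRegistrationInfo's constructor in this commit.
void RegisterMyOpCondition(OpRegistryBase *op_registry) {
  op_registry->Register(
      OpConditionBuilder("MyOp")
          .SetDevicePlacerFunc(
              [](OpConstructContext *context) -> std::set<DeviceType> {
                auto op = context->operator_def();
                if (op->output_shape_size() != op->output_size() ||
                    op->output_shape(0).dims_size() != 4) {
                  return {DeviceType::CPU};
                }
                return {DeviceType::CPU, DeviceType::GPU};
              }));
}

}  // namespace mace

At graph construction time, SerialNet::CreateOperation then calls op_registry->AvailableDevices(op_def->type(), construct_context), which dispatches to the registered placer (or the default one) as shown in the mace/core/net.cc diff below.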
136 changes: 39 additions & 97 deletions mace/core/net.cc
@@ -64,92 +64,26 @@ bool TransformRequiredOp(const std::string &op_type) {
}
#endif // MACE_ENABLE_OPENCL



// TODO(lichao): Move to runtime driver class after universality done.
// fallback to gpu buffer when kernels are implemented
void FindAvailableDevicesForOp(const OpRegistryBase &op_registry,
const OperatorDef &op,
const std::unordered_map<std::string,
std::vector<index_t>> &tensor_shape_info,
std::set<DeviceType>
*available_devices) {
auto devices = op_registry.AvailableDevices(op.type());
available_devices->insert(devices.begin(), devices.end());
std::string op_type = op.type();
// For those whose shape is not 4-rank but can run on GPU
if (op_type == "BufferTransform"
|| op_type == "LSTMCell"
|| op_type == "FullyConnected"
|| op_type == "Softmax"
|| op_type == "Squeeze") {
return;
} else {
if (op.output_shape_size() != op.output_size()) {
return;
}
if (op.output_shape(0).dims_size() != 4) {
available_devices->erase(DeviceType::GPU);
}

if (op_type == "Split") {
if (op.output_shape(0).dims_size() != 4
|| op.output_shape(0).dims()[3] % 4 != 0) {
available_devices->erase(DeviceType::GPU);
}
} else if (op_type == "Concat") {
if (op.output_shape(0).dims_size() != 4) {
available_devices->erase(DeviceType::GPU);
} else {
if (op.input_size() != 2) {
for (const std::string &input : op.input()) {
if (tensor_shape_info.find(input) != tensor_shape_info.end()) {
auto &input_shape = tensor_shape_info.at(input);
if (input_shape[3] % 4 != 0) {
available_devices->erase(DeviceType::GPU);
break;
}
}
}
}
}
} else if (op_type == "ChannelShuffle") {
int groups = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "group", 1);
int channels = op.output_shape(0).dims(3);
int channels_per_group = channels / groups;
if (groups % 4 != 0 || channels_per_group % 4 != 0) {
available_devices->erase(DeviceType::GPU);
}
}
}
}

} // namespace

std::unique_ptr<Operation> SerialNet::CreateOperation(
const OpRegistryBase *op_registry,
OpConstructContext *construct_context,
std::shared_ptr<OperatorDef> op_def,
const std::unordered_map<std::string,
std::vector<index_t>> tensor_shape_info,
DataFormat data_format_flag,
bool is_quantize_model) {
// Create the Operation
DeviceType target_device_type = target_device_->device_type();
// Get available devices
std::set<DeviceType> available_devices;
FindAvailableDevicesForOp(*op_registry,
*op_def,
tensor_shape_info,
&available_devices);
// Find the device type to run the op.
// If the target_device_type in available devices, use target_device_type,
// otherwise, fallback to CPU device.
DeviceType device_type = DeviceType::CPU;
construct_context->set_device(cpu_device_);
construct_context->set_operator_def(op_def);
construct_context->set_output_mem_type(MemoryType::CPU_BUFFER);
// Get available devices
auto available_devices =
op_registry->AvailableDevices(op_def->type(), construct_context);
// Find the device type to run the op.
// If the target_device_type in available devices, use target_device_type,
// otherwise, fallback to CPU device.
for (auto device : available_devices) {
if (device == target_device_type) {
device_type = target_device_type;
@@ -208,6 +142,23 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
MemoryType target_mem_type;
// quantize model flag
bool is_quantize_model = IsQuantizedModel(*net_def);
// Tensor Shape map
std::unordered_map<std::string, std::vector<index_t>> tensor_shape_map;
for (auto &op : net_def->op()) {
if (op.output_size() != op.output_shape_size()) {
continue;
}
for (int i = 0; i < op.output_size(); ++i) {
tensor_shape_map[op.output(i)] =
std::move(std::vector<index_t>(op.output_shape(i).dims().begin(),
op.output_shape(i).dims().end()));
}
}
for (auto &tensor : net_def->tensors()) {
tensor_shape_map[tensor.name()] =
std::move(std::vector<index_t>(tensor.dims().begin(),
tensor.dims().end()));
}

DataFormat data_format_flag = NHWC;
if (target_device_->device_type() == DeviceType::CPU) {
@@ -216,11 +167,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
std::vector<index_t> input_shape =
std::vector<index_t>(input_info.dims().begin(),
input_info.dims().end());
// update tensor shape map
tensor_shape_map[input_info.name()] = input_shape;
// Only could be NONE or NHWC
auto input_data_format = static_cast<DataFormat>(
input_info.data_format());
if (!is_quantize_model &&
input_data_format == NHWC &&
if (!is_quantize_model && input_data_format == NHWC &&
input_info.dims_size() == 4) {
// NHWC -> NCHW
input_shape =
@@ -237,39 +189,29 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
else { // GPU NOLINT[readability/braces]
target_mem_type = MemoryType::GPU_BUFFER;
for (auto &input_info : net_def->input_info()) {
auto input_data_format = static_cast<DataFormat>(
input_info.data_format());
if (input_data_format == DataFormat::DF_NONE) {
data_format_flag = DataFormat::DF_NONE;
}
std::vector<index_t> input_shape =
std::vector<index_t>(input_info.dims().begin(),
input_info.dims().end());
// update tensor shape map
tensor_shape_map[input_info.name()] = input_shape;
output_map.emplace(input_info.name(), InternalOutputInfo(
target_mem_type, DataType::DT_FLOAT, input_shape, -1));
}
}
#endif // MACE_ENABLE_OPENCL

std::unordered_map<std::string, std::vector<index_t>> tensor_shape_info;
for (auto &op : net_def->op()) {
if (op.output_size() != op.output_shape_size()) {
continue;
}
for (int i = 0; i < op.output_size(); ++i) {
tensor_shape_info[op.output(i)] =
std::move(std::vector<index_t>(op.output_shape(i).dims().begin(),
op.output_shape(i).dims().end()));
}
}
for (auto &tensor : net_def->tensors()) {
tensor_shape_info[tensor.name()] =
std::move(std::vector<index_t>(tensor.dims().begin(),
tensor.dims().end()));
}
OpConstructContext construct_context(ws_);
OpConstructContext construct_context(ws_, &tensor_shape_map);
for (int idx = 0; idx < net_def->op_size(); ++idx) {
std::shared_ptr<OperatorDef> op_def(new OperatorDef(net_def->op(idx)));
// Create operation
auto op = CreateOperation(op_registry,
&construct_context,
op_def,
tensor_shape_info,
data_format_flag,
is_quantize_model);
#ifdef MACE_ENABLE_OPENCL
@@ -317,12 +259,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
}
auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
input_name, input_shape, t_input_name,
wanted_in_dt, wanted_in_mem_type);
wanted_in_dt, wanted_in_mem_type, data_format_flag);
OpConstructContext t_construct_context(ws_);
auto transform_op = CreateOperation(
op_registry,
&construct_context,
&t_construct_context,
transform_op_def,
tensor_shape_info,
data_format_flag);
operators_.emplace_back(std::move(transform_op));
transformed_set.insert(t_input_name);
@@ -405,12 +347,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
internal_output_info.shape,
output_info.name(),
output_info.data_type(),
target_mem_type);
target_mem_type,
data_format_flag);
auto transform_op = CreateOperation(
op_registry,
&construct_context,
transform_op_def,
tensor_shape_info,
output_data_format);
operators_.emplace_back(std::move(transform_op));
// where to do graph reference count.
2 changes: 0 additions & 2 deletions mace/core/net.h
@@ -59,8 +59,6 @@ class SerialNet : public NetBase {
const OpRegistryBase *op_registry,
OpConstructContext *construct_context,
std::shared_ptr<OperatorDef> op_def,
const std::unordered_map<std::string,
std::vector<index_t>> tensor_shape_info,
DataFormat input_format,
bool is_quantize_model = false);

71 changes: 63 additions & 8 deletions mace/core/operator.cc
@@ -22,7 +22,18 @@
namespace mace {

OpConstructContext::OpConstructContext(Workspace *ws)
: operator_def_(nullptr), ws_(ws), device_(nullptr) {}
: operator_def_(nullptr),
ws_(ws),
device_(nullptr),
tensor_shape_info_(nullptr) {}

OpConstructContext::OpConstructContext(
mace::Workspace *ws,
mace::OpConstructContext::TensorShapeMap *info)
: operator_def_(nullptr),
ws_(ws),
device_(nullptr),
tensor_shape_info_(info) {}

void OpConstructContext::set_operator_def(
std::shared_ptr<mace::OperatorDef> operator_def) {
@@ -169,6 +180,19 @@ const std::string OpKeyBuilder::Build() {
}
} // namespace

OpRegistrationInfo::OpRegistrationInfo() {
device_placer = [this](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
// The GPU ops only support 4D In/Out tensor by default
if (this->devices.count(DeviceType::CPU) == 1 &&
op->output_shape_size() == op->output_size() &&
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return this->devices;
};
}

void OpRegistrationInfo::AddDevice(mace::DeviceType device) {
devices.insert(device);
}
@@ -179,10 +203,11 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) {
creators[key] = creator;
}

MaceStatus OpRegistryBase::Register(const std::string &op_type,
const mace::DeviceType device_type,
const mace::DataType dt,
mace::OpRegistrationInfo::OpCreator creator) {
MaceStatus OpRegistryBase::Register(
const std::string &op_type,
const mace::DeviceType device_type,
const mace::DataType dt,
mace::OpRegistrationInfo::OpCreator creator) {
if (registry_.count(op_type) == 0) {
registry_[op_type] = std::unique_ptr<OpRegistrationInfo>(
new OpRegistrationInfo);
@@ -197,15 +222,25 @@ MaceStatus OpRegistryBase::Register(const std::string &op_type,
return MaceStatus::MACE_SUCCESS;
}

MaceStatus OpRegistryBase::Register(
const OpConditionBuilder &builder) {
std::string op_type = builder.type();
if (registry_.count(op_type) == 0) {
registry_[op_type] = std::unique_ptr<OpRegistrationInfo>(
new OpRegistrationInfo);
}
builder.Finalize(registry_[op_type].get());
return MaceStatus::MACE_SUCCESS;
}

const std::set<DeviceType> OpRegistryBase::AvailableDevices(
const std::string &op_type) const {
const std::string &op_type, OpConstructContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");

return registry_.at(op_type)->devices;
return registry_.at(op_type)->device_placer(context);
}


std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
OpConstructContext *context,
DeviceType device_type) const {
@@ -238,4 +273,24 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
}
return registry_.at(op_type)->creators.at(key)(context);
}

OpConditionBuilder::OpConditionBuilder(const std::string &type)
: type_(type) {}

const std::string OpConditionBuilder::type() const {
return type_;
}

OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
OpRegistrationInfo::DevicePlacer placer) {
placer_ = placer;
return *this;
}

void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
if (info != nullptr && placer_) {
info->device_placer = placer_;
}
}

} // namespace mace
