Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python Generate OpCreation Methods by OpProto #2893

Merged
merged 5 commits into from
Jul 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <algorithm>
#include <atomic>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -214,25 +215,35 @@ class OpRegistry {
}

static OperatorPtr CreateOp(const OpDesc& op_desc) {
//! Create a OpPtr by type.
std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)());
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
const OpProto& op_proto = protos().at(op_type);
// set op's inputs_ from desc.
op->type_ = op_desc.type();
// set op's inputs_ from desc.
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
// set op's outputs_ from desc.
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
// set op's attr;

//! Fill attrs, and validate attrs.
for (auto& attr : op_desc.attrs()) {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);

//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName(op.get());
Copy link
Member

@jacquesqiao jacquesqiao Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为temp var name是c++这边自动生成的,所以python那边就管不了这些Var了是么,这些Var应该如何创建,被谁创建?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python这边会在运行Op之前扫描这个Op的输入和输出,进而创建对应的variable


// set argument offsets stored in op.
CreateInOutOffsetMap(op, op_proto);
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
op->Init();
return op;
}
Expand All @@ -248,6 +259,17 @@ class OpRegistry {
};

private:
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
if (outname == OperatorBase::TMP_VAR_NAME()) {
outname += op->type_;
outname += "@";
outname += std::to_string(gUniqId.fetch_add(1));
}
}
}

static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
Expand Down
28 changes: 13 additions & 15 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,21 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {

std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
ss << "type = " << type_ << "\n";
ss << "inputs = [";
for (auto& ipt : inputs_) {
ss << ipt << ", ";
ss << "Op(" << type_ << "), inputs:(";
for (size_t i = 0; i < inputs_.size(); ++i) {
ss << inputs_[i];
if (i != inputs_.size() - 1) {
ss << ", ";
}
}
ss << "]\n";
ss << "outputs = [";
for (auto& opt : outputs_) {
ss << opt << ", ";
ss << "), outputs:(";
for (size_t i = 0; i < outputs_.size(); ++i) {
ss << outputs_[i];
if (i != outputs_.size() - 1) {
ss << ", ";
}
}
ss << "]\n";
ss << "attr_keys = [";
for (auto& attr : attrs_) {
ss << attr.first << ", ";
}
ss << "]\n";
ss << ").";
return ss.str();
}

Expand Down
7 changes: 7 additions & 0 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
*/
class OperatorBase {
public:
/// If a variable is a empty variable, that name will be used.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should add some more comment about what is empty variable and temp variable

static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }

virtual ~OperatorBase() {}

template <typename T>
Expand Down
17 changes: 17 additions & 0 deletions paddle/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ All parameter, weight, gradient are variables in Paddle.
}
return ret_values;
});
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME);

py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
.def("__str__", &pd::OperatorBase::DebugString)
.def_static("create", [](const std::string& protobin) {
pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return pd::OpRegistry::CreateOp(desc);
});

return m.ptr();
}
235 changes: 235 additions & 0 deletions python/paddle/v2/framework/create_op_creation_methods.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,246 @@
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
import cStringIO


def get_all_op_protos():
"""
Get all registered op proto from Paddle C++
:return: list of OpProto
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = op_proto_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values


class OpDescCreationMethod(object):
Copy link
Member

@jacquesqiao jacquesqiao Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class OpDescCreator(object):
  def create(self, *args, **kwargs):
    ...
    return op_desc     

I think above is better for reading and understanding code.

Copy link
Collaborator Author

@reyoung reyoung Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A Function object(https://en.wikipedia.org/wiki/Function_object) or functor is a very common way for many programming languages.

Using Function object can hold internal states and make function implementation into small pieces and more readable.

"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.

:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""

def __init__(self, op_proto):
if not isinstance(op_proto, op_proto_pb2.OpProto):
raise TypeError("Argument should be OpProto")
self.__op_proto__ = op_proto

def __call__(self, *args, **kwargs):
"""
Convert user input to OpDesc. Only key-word args are supported.
:return: OpDesc based on user input
:rtype: op_desc_pb2.OpDesc
"""
if len(args) != 0:
raise ValueError("Only keyword arguments is supported by Paddle")
op_desc = op_desc_pb2.OpDesc()

# Inputs
ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output(
"input", kwargs, self.__op_proto__.inputs)
op_desc.inputs.extend(ipts)
if ipt_format is not None:
op_desc.attrs.extend([ipt_format])

# Outputs
outs, out_format, tmp_index = OpDescCreationMethod.extract_input_or_output(
"output", kwargs, self.__op_proto__.outputs)
op_desc.outputs.extend(outs)
if out_format is not None:
op_desc.attrs.extend([out_format])
if len(tmp_index) != 0:
tmp_index_attr = op_desc.attrs.add()
tmp_index_attr.type = attr_type_pb2.INTS
tmp_index_attr.name = "temporary_index"
tmp_index_attr.ints.extend(tmp_index)

# Types
op_desc.type = self.__op_proto__.type

# Attrs
for attr in self.__op_proto__.attrs:
if attr.generated:
continue
user_defined_attr = kwargs.get(attr.name, None)
if user_defined_attr is not None:
new_attr = op_desc.attrs.add()
new_attr.name = attr.name
new_attr.type = attr.type
if attr.type == attr_type_pb2.INT:
new_attr.i = user_defined_attr
elif attr.type == attr_type_pb2.FLOAT:
new_attr.f = user_defined_attr
elif attr.type == attr_type_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == attr_type_pb2.INTS:
new_attr.ints.extend(user_defined_attr)
elif attr.type == attr_type_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr)
elif attr.type == attr_type_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr)
else:
raise NotImplementedError("Not support attribute type " +
attr.type)

return op_desc

@staticmethod
def extract_input_or_output(in_out, kwargs, meta):
"""
Extract input variable names or output variable names from key-word
arguments, which base on VarProtos.

:param in_out: "input" or "output"
:param kwargs: key-word arguments that user inputted.
:param meta: a list of VarProto
:return: The three object will be return. The variable names. The
input_format or output_format attribute(None if the input or output is
not multiple). The temporary variable index list.
"""
multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta))
tmp_index = []
retv = []
if multiple:
var_format = op_desc_pb2.AttrDesc()
var_format.type = attr_type_pb2.INTS
var_format.name = "%s_format" % in_out
var_format.ints.append(0)

for var in meta:
var_name = var.name

if var.temporary:
var_name = [core.var_names.temp()]
tmp_index.append(len(retv))
else:
var_name = kwargs.get(var_name, [])
if not isinstance(var_name, list):
var_name = [var_name]
retv.extend(var_name)
var_format.ints.append(len(var_name) + var_format.ints[-1])
return retv, var_format, tmp_index
else:
for var in meta:
if var.temporary:
retv.append(kwargs.get(var.name, core.var_names.temp()))
tmp_index.append(len(retv))
else:
retv.append(kwargs.get(var.name, core.var_names.empty()))
return retv, None, tmp_index

@staticmethod
def any_is_true(generator):
"""
Reduce a bool array to one. If any of them is True, then return True.
"""
for flag in generator:
if flag:
return True
return False


def get_docstring_from_op_proto(op_proto):
"""
Generate docstring from a OpProto
:param op_proto: a OpProto instance.
:type op_proto: op_proto_pb2.OpProto
:return: docstring
"""
if not isinstance(op_proto, op_proto_pb2.OpProto):
raise TypeError("Input must be OpProto")
f = cStringIO.StringIO()
f.write(op_proto.comment)
f.write("\n")

def __append_param__(name, comment, type):
# Maybe replace the following line with template engine is better.
f.write(":param ")
f.write(name)
f.write(": ")
f.write(comment)
f.write("\n")
f.write(":type ")
f.write(name)
f.write(": ")
f.write(type)
f.write("\n")

for ipt in op_proto.inputs:
__append_param__(ipt.name, ipt.comment, "list | basestr"
if ipt.multiple else "basestr")

temp_var_prefix = \
"This is a temporary variable. It does not have to set by user. "
for opt in op_proto.outputs:
__append_param__(opt.name, opt.comment if not opt.temporary else
temp_var_prefix + opt.comment, "list | basestr"
if opt.multiple else "basestr")

for attr in op_proto.attrs:
attr_type = None
if attr.type == attr_type_pb2.INT:
attr_type = "int"
elif attr.type == attr_type_pb2.FLOAT:
attr_type = "float"
elif attr.type == attr_type_pb2.STRING:
attr_type = "basestr"
elif attr.type == attr_type_pb2.INTS:
attr_type = "list of int"
elif attr.type == attr_type_pb2.FLOATS:
attr_type = "list of float"
elif attr.type == attr_type_pb2.STRINGS:
attr_type = "list of basestr"

if attr_type is None:
raise RuntimeError("Not supported attribute type " + attr.type)

__append_param__(attr.name, attr.comment, attr_type)

return f.getvalue()


def create_op_creation_method(op_proto):
"""
Generate op creation method for an OpProto
"""
method = OpDescCreationMethod(op_proto)

def __impl__(*args, **kwargs):
opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString())

__impl__.__doc__ = get_docstring_from_op_proto(op_proto)
return __impl__


class OpCreationsHolder(object):
"""
A object will holds all op creation methods.

Use `op_creations.xxx_op` to access them.
"""
pass


op_creations = OpCreationsHolder()


def __bootstrap__():
"""
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for op_proto in get_all_op_protos():
func = create_op_creation_method(op_proto)
func.__name__ = str(op_proto.type)
setattr(op_creations, func.__name__, func)


__bootstrap__()
Loading