diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index 21c60a3c86438..fe0c87bc57082 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/tensor_py.h"
 #include "paddle/operators/net_op.h"
+#include "paddle/operators/recurrent_op.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "paddle/string/to_string.h"
@@ -241,6 +242,11 @@ All parameter, weight, gradient are variables in Paddle.
              const std::shared_ptr<operators::NetOp> &net) -> void {
             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
           })
+      .def("add_op",
+           [](operators::NetOp &self,
+              const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
+             self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
+           })
       .def("complete_add_op", &operators::NetOp::CompleteAddOp)
       .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
         self->CompleteAddOp();
@@ -248,6 +254,29 @@ All parameter, weight, gradient are variables in Paddle.
 
   ExposeOperator(net);
 
+  // recurrent_op
+  py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
+      rnn(m, "RecurrentOp");
+
+  rnn.def_static(
+         "create",
+         [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
+           OpDesc desc;
+           PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                          "Cannot parse user input to OpDesc");
+           PADDLE_ENFORCE(desc.IsInitialized(),
+                          "User OpDesc is not initialized, reason %s",
+                          desc.InitializationErrorString());
+           auto rnn_op = OpRegistry::CreateOp(desc);
+           return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
+         })
+      .def("set_stepnet",
+           [](operators::RecurrentOp &self,
+              const std::shared_ptr<operators::NetOp> &net) -> void {
+             self.set_stepnet(net);
+           });
+  ExposeOperator(rnn);
+
   m.def("unique_integer", UniqueIntegerGenerator);
 
   m.def("is_compile_gpu", IsCompileGPU);
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index e5ff3b2f7ec37..a7c89787e43df 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -66,6 +66,5 @@
 op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
 
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
     DEPS framework_proto tensor op_registry operator net_op)
-cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
 op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc
index f61e1288d3cd8..78ce0ba3c0fa4 100644
--- a/paddle/operators/recurrent_op.cc
+++ b/paddle/operators/recurrent_op.cc
@@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      true /*infer_shape_mode*/);
   InitMemories(step_scopes[0], true /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (size_t i = 0; i < seq_len_; i++) {
     if (i > 0) {
       rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
                         true /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
+    (*stepnet_)->InferShape(*step_scopes[i]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope,
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      false /*infer_shape_mode*/);
   InitMemories(step_scopes[0], false /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
 
   for (size_t step_id = 0; step_id < seq_len_; step_id++) {
     // create output alias variables
@@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
                         false /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      false /*infer_shape_mode*/);
@@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   auto step_scopes =
       step_scopes_var->GetMutable<std::vector<Scope*>>();
 
   // Now all variables in scope must be created outside of op.
-  auto net_var = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
-                 arg_->step_net);
-  auto net_op = net_var->GetMutable<NetOp>();
-  PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
+  PADDLE_ENFORCE_NOT_NULL(stepnet_);
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
 
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
-      for (auto& input : net_op->Inputs()) {
+      for (auto& input : (*stepnet_)->Inputs()) {
         // the weight are located in parent scope
         for (auto& var_name : input.second) {
           if (!step_scope.FindVar(var_name)) {
@@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
         }
       }
       // create stepnet's outputs
-      for (const auto& output : net_op->Outputs()) {
+      for (const auto& output : (*stepnet_)->Outputs()) {
         for (auto& var_name : output.second) {
           step_scope.NewVar(var_name);
         }
@@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::OperatorBase::VarNameMap& outputs,
                          const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
-  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
-  rnn::InitArgument(kArgName, arg.get(), *this);
-  alg_.Init(std::move(arg));
+  rnn::InitArgument(kArgName, &arg_, *this);
+  alg_.Init(&arg_, &stepnet_);
 }
 
 class RecurrentAlgorithmProtoAndCheckerMaker
@@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
         .AsDuplicable();
     AddInput(name.boot_memories, "variables to initialize memories.")
         .AsDuplicable();
-    AddInput(name.step_net, "network shared by all steps.");
 
     AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
         .AsDuplicable();
@@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run(
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      false /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         false /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
@@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      true /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         true /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
+    (*stepnet_)->InferShape(*step_scopes[step_id]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -238,9 +227,8 @@ RecurrentGradientOp::RecurrentGradientOp(
     const framework::OperatorBase::VarNameMap& outputs,
     const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
-  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
-  rnn::InitArgument(kArgName, arg.get(), *this);
-  alg_.Init(std::move(arg));
+  rnn::InitArgument(kArgName, &arg_, *this);
+  alg_.Init(&arg_, &stepnet_);
 }
 
 }  // namespace operators
diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h
index 8f4f2444d844b..caca644c96c3f 100644
--- a/paddle/operators/recurrent_op.h
+++ b/paddle/operators/recurrent_op.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/framework/operator.h"
+#include "paddle/operators/net_op.h"
 #include "paddle/operators/rnn/recurrent_op_utils.h"
 
 namespace paddle {
@@ -33,7 +34,11 @@ class RecurrentAlgorithm {
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
 
-  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
+  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
+    arg_ = arg;
+    stepnet_ = stepnet;
+  }
 
   /**
    * InferShape must be called before Run.
@@ -58,7 +63,8 @@ class RecurrentAlgorithm {
   void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
  private:
-  std::unique_ptr<rnn::Argument> arg_;
+  std::shared_ptr<NetOp>* stepnet_;
+  rnn::Argument* arg_;
   mutable size_t seq_len_;
 };
 
@@ -74,7 +80,11 @@ class RecurrentGradientAlgorithm {
    * operator.
    */
  public:
-  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
+  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
+    arg_ = std::move(arg);
+    stepnet_ = stepnet;
+  }
 
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
@@ -95,8 +105,9 @@ class RecurrentGradientAlgorithm {
   }
 
  private:
-  std::unique_ptr<rnn::Argument> arg_;
+  rnn::Argument* arg_;
   mutable size_t seq_len_;
+  std::shared_ptr<NetOp>* stepnet_;
 };
 
 class RecurrentOp final : public framework::OperatorBase {
@@ -115,10 +126,15 @@ class RecurrentOp final : public framework::OperatorBase {
     alg_.Run(scope, dev_ctx);
   }
 
+  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
+  const NetOp* stepnet() const { return stepnet_.get(); }
+
   static const rnn::ArgumentName kArgName;
 
  private:
   RecurrentAlgorithm alg_;
+  rnn::Argument arg_;
+  std::shared_ptr<NetOp> stepnet_;
 };
 
 class RecurrentGradientOp final : public framework::OperatorBase {
@@ -141,8 +157,13 @@ class RecurrentGradientOp final : public framework::OperatorBase {
 
   static const rnn::ArgumentName kArgName;
 
+  void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
+  const NetOp* stepnet() const { return stepnet_.get(); }
+
  private:
   RecurrentGradientAlgorithm alg_;
+  std::shared_ptr<NetOp> stepnet_;
+  rnn::Argument arg_;
 };
diff --git a/paddle/operators/recurrent_op_test.cc b/paddle/operators/recurrent_op_test.cc
deleted file mode 100644
index 2f6eff0720847..0000000000000
--- a/paddle/operators/recurrent_op_test.cc
+++ /dev/null
@@ -1,252 +0,0 @@
-/*
-  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
-  http://www.apache.org/licenses/LICENSE-2.0
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-*/
-
-#include "paddle/operators/recurrent_op.h"
-
-#include <glog/logging.h>
-#include <gtest/gtest.h>
-
-#include "paddle/framework/ddim.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/operator.h"
-#include "paddle/framework/tensor.h"
-#include "paddle/operators/net_op.h"
-
-namespace paddle {
-namespace operators {
-
-using namespace paddle::framework;
-
-class RecurrentGradientAlgorithmTest : public ::testing::Test {
- protected:
-  virtual void SetUp() override {
-    CreateGlobalVariables();
-    CreateStepScopes();
-    CreateStepNet();
-    CreateRNNGradientAlgorithm();
-
-    // segment inputs
-    SegmentInputs();
-    // link forward memories
-    LinkeMemories();
-  }
-
-  virtual void TearDown() override {}
-
-  void CreateGlobalVariables() {
-    // inputs: x
-    LOG(INFO) << "create global variable x";
-    Variable* x = scope_.NewVar("x");
-    DDim dims =
-        make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
-    x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
-    // inputs: h_boot
-    LOG(INFO) << "create global variable h_boot";
-    Variable* h_boot = scope_.NewVar("h_boot");
-    h_boot->GetMutable<Tensor>()->mutable_data<float>(
-        make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace());
-    // inputs: w
-    LOG(INFO) << "create global variable w";
-    Variable* w = scope_.NewVar("rnn/w");
-    w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
-                                                 platform::CPUPlace());
-    // inputs: h_grad
-    LOG(INFO) << "create variable h_grad";
-    Variable* dh = scope_.NewVar("h_grad");
-    dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
-                                                  platform::CPUPlace());
-    // inputs: step_scopes
-    LOG(INFO) << "create variable step_scopes";
-    scope_.NewVar("step_scopes");
-    // inputs: step_net
-    LOG(INFO) << "create variable step_net";
-    scope_.NewVar("step_net");
-    // outputs: w_grad
-    LOG(INFO) << "create global variable w_grad";
-    scope_.NewVar("rnn/w_grad");
-    // outputs: x_grad
-    LOG(INFO) << "create global variable x_grad";
-    scope_.NewVar("x_grad");
-    // outputs: h_boot_grad
-    LOG(INFO) << "create global variable h_boot_grad";
-    scope_.NewVar("h_boot_grad");
-  }
-
-  void CreateStepScopes() {
-    auto step_scopes =
-        scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
-    for (int i = 0; i < 10; ++i) {
-      auto& scope = scope_.NewScope();
-      auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable<Tensor>();
-      pre_t->mutable_data<float>({20, 30}, platform::CPUPlace());
-      auto tensor = scope.NewVar("rnn/h")->GetMutable<Tensor>();
-      tensor->mutable_data<float>({20, 30}, platform::CPUPlace());
-
-      // for unit test of ConcatOutputs
-      auto xg = scope.NewVar("rnn/x_grad")->GetMutable<Tensor>();
-      xg->mutable_data<float>({20, 30}, platform::CPUPlace());
-
-      step_scopes->emplace_back(&scope);
-    }
-
-    // last time step
-    auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable<Tensor>();
-    g->mutable_data<float>({20, 30}, platform::CPUPlace());
-  }
-
-  void CreateRNNGradientAlgorithm() {
-    std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
-    arg->step_net = "step_net";
-    arg->step_scopes = "step_scopes";
-    rnn::Link inlink;
-    inlink.external = "h_grad";
-    inlink.internal = "rnn/h_grad";
-    arg->inlinks = std::vector<rnn::Link>{inlink};
-
-    rnn::Link outlink;
-    outlink.external = "x_grad";
-    outlink.internal = "rnn/x_grad";
-    arg->outlinks = std::vector<rnn::Link>{outlink};
-
-    rnn::MemoryAttr mem_attr;
-    mem_attr.pre_var = "rnn/h_pre_grad";
-    mem_attr.var = "rnn/h_grad";
-    mem_attr.boot_var = "h_boot_grad";
-    arg->memories = std::vector<rnn::MemoryAttr>{mem_attr};
-
-    rnn_grad_algo_.Init(std::move(arg));
-  }
-
-  void CreateStepNet() {
-    LOG(INFO) << "create variable step_net";
-    Variable* var = scope_.NewVar("step_net");
-    auto net = var->GetMutable<NetOp>();
-    // TODO(qingqing) modify backward op create for RNNOp unit test
-    // and the unit test will be removed to Python.
-    // net->AddOp(OpRegistry::CreateOp("mul", {"X", {"rnn/h_pre", "rnn/w",
-    // "rnn/s_grad"}}, {"Y", {"rnn/h_pre_grad", "rnn/w_grad"}}, {}));
-
-    // net->AddOp(OpRegistry::CreateOp("add_two", {"X", {"rnn/h_grad"}},
-    // {"Y", {"rnn/x_grad"}}, {"Out", "rnn/s_grad"}}, {}));
-    net->CompleteAddOp();
-  }
-
-  void SegmentInputs() {
-    LOG(INFO) << "segment inputs";
-    std::vector<std::string> inlinks = {"x"};
-    std::vector<std::string> inlinks_alias = {"rnn/x"};
-
-    rnn::Link inlink;
-    inlink.external = "x";
-    inlink.internal = "rnn/x";
-    auto step_scopes =
-        scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
-    rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10,
-                       true /*infer_shape_mode*/);
-  }
-
-  void LinkeMemories() {
-    LOG(INFO) << "link memories";
-    rnn::MemoryAttr mem_attr;
-    mem_attr.pre_var = "rnn/h_pre";
-    mem_attr.var = "rnn/h";
-    mem_attr.boot_var = "boot_h";
-    std::vector<rnn::MemoryAttr> memories;
-    memories.push_back(mem_attr);
-    auto step_scopes =
-        scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
-    for (int i = 1; i < 10; ++i) {
-      rnn::LinkMemories(*step_scopes, memories, i, -1,
-                        true /*infer_shape_mode*/);
-    }
-  }
-
-  Scope scope_;
-  RecurrentGradientAlgorithm rnn_grad_algo_;
-};
-
-// TEST_F(RecurrentGradientAlgorithmTest, Run) {
-//   platform::CPUDeviceContext ctx;
-//   rnn_grad_algo_.Run(scope_, ctx);
-// }
-
-}  // namespace operators
-}  // namespace paddle
-
-TEST(RecurrentOp, LinkMemories) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
-  using namespace paddle::operators;
-
-  // create and init step scopes
-  size_t len = 10;
-  std::vector<Scope*> step_scopes;
-  for (size_t i = 0; i < len; ++i) {
-    auto scope = new Scope();
-    scope->NewVar("pre_h");
-    auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
-    float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
-    for (size_t j = 0; j < 15 * 20; ++j) {
-      data[j] = rand() * (1. / (double)RAND_MAX);
-    }
-    step_scopes.push_back(scope);
-  }
-
-  // create MemoryAttr
-  rnn::MemoryAttr mem_attr;
-  mem_attr.pre_var = "pre_h";
-  mem_attr.var = "h";
-  mem_attr.boot_var = "boot_h";
-  std::vector<rnn::MemoryAttr> memories;
-  memories.push_back(mem_attr);
-
-  for (size_t i = 1; i < len; ++i) {
-    rnn::LinkMemories(step_scopes, memories, i, -1, false
-                      /*infer_shape_mode*/);
-  }
-  // check
-  for (size_t i = 0; i < len - 1; ++i) {
-    const float* a =
-        step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
-    const float* b = step_scopes[i + 1]
-                         ->FindVar("pre_h")
-                         ->GetMutable<Tensor>()
-                         ->data<float>();
-    for (size_t j = 0; j < 15 * 20; ++j) {
-      ASSERT_FLOAT_EQ(a[j], b[j]);
-    }
-  }
-
-  for (int i = len - 2; i >= 0; --i) {
-    rnn::LinkMemories(step_scopes, memories, i, 1, false
-                      /*infer_shape_mode*/);
-  }
-  // check
-  for (int i = len - 2; i >= 0; --i) {
-    const float* a =
-        step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
-    const float* b =
-        step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
-    for (size_t j = 0; j < 15 * 20; ++j) {
-      ASSERT_FLOAT_EQ(a[j], b[j]);
-    }
-  }
-
-  for (auto s : step_scopes) {
-    delete s;
-  }
-}
-
-USE_OP(add_two);
-USE_OP(mul);
-USE_OP_ITSELF(recurrent_op);
diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc
index 7e4770630ed2a..a9b65c30f2555 100644
--- a/paddle/operators/rnn/recurrent_op_utils.cc
+++ b/paddle/operators/rnn/recurrent_op_utils.cc
@@ -106,7 +106,6 @@ void LinkMemories(const std::vector<Scope*>& scopes,
 
 void InitArgument(const ArgumentName& name, Argument* arg,
                   const framework::OperatorBase& op) {
-  arg->step_net = op.Input(name.step_net);
   arg->step_scopes = op.Output(name.step_scopes);
 
   auto inlinks = op.Inputs(name.inlinks);
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 904de08da4efa..6ac656321e72f 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -23,7 +23,7 @@ class OpDescCreationMethod(object):
     """
     A Functor object to convert user input(use key word args) to OpDesc based on
    OpProto.
-    
+
     :param op_proto: The OpProto object.
     :type op_proto: op_proto_pb2.OpProto
     """
@@ -177,4 +177,26 @@ def get_op_attr_names(self, type):
         return self.get_op_info(type).attrs
 
 
+class __RecurrentOp__(object):
+    __proto__ = None
+    type = 'recurrent_op'
+
+    def __init__(self):
+        # cache recurrent_op's proto
+        if self.__proto__ is None:
+            for op_proto in get_all_op_protos():
+                if op_proto.type == self.type:
+                    self.__proto__ = op_proto
+
+    def __call__(self, *args, **kwargs):
+        if self.type not in args and 'type' not in kwargs:
+            kwargs['type'] = self.type
+        # create proto
+        create_method = OpDescCreationMethod(self.__proto__)
+        proto = create_method(*args, **kwargs)
+        # create rnnop
+        return core.RecurrentOp.create(proto.SerializeToString())
+
+
 Operator = OperatorFactory()  # Default global factory
+RecurrentOp = __RecurrentOp__()
diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py
index 0db66cc4e181f..3d4a34d8d713f 100644
--- a/python/paddle/v2/framework/tests/test_recurrent_op.py
+++ b/python/paddle/v2/framework/tests/test_recurrent_op.py
@@ -2,7 +2,7 @@
 import paddle.v2.framework.core as core
 import unittest
 import numpy as np
-from paddle.v2.framework.op import Operator
+from paddle.v2.framework.op import Operator, RecurrentOp
 
 
 def py_sigmoid(x):
@@ -98,11 +98,11 @@ def setUp(self):
     def forward(self):
        self.scope = core.Scope()
        self.create_global_variables()
+        self.create_rnn_op()
        self.create_step_net()
-        rnn_op = self.create_rnn_op()
        ctx = core.DeviceContext.create(core.CPUPlace())
-        rnn_op.infer_shape(self.scope)
-        rnn_op.run(self.scope, ctx)
+        self.rnnop.infer_shape(self.scope)
+        self.rnnop.run(self.scope, ctx)
        return np.array(self.scope.find_var("h").get_tensor())
 
     def create_global_variables(self):
@@ -128,8 +128,7 @@ def create_global_variables(self):
 
     def create_rnn_op(self):
         # create RNNOp
-        rnnop = Operator(
-            "recurrent_op",
+        self.rnnop = RecurrentOp(
             # inputs
             inlinks=["x"],
             boot_memories=["h_boot"],
@@ -142,14 +141,9 @@ def create_rnn_op(self):
             outlink_alias=["h@alias"],
             pre_memories=["h@pre"],
             memories=["h@alias"])
-        return rnnop
 
     def create_step_net(self):
-        var = self.scope.new_var("stepnet")
-        stepnet = var.get_net()
-
-        # x_fc_op = Operator("fc", X="x@alias", W="W", Y="Wx")
-        # h_fc_op = Operator("fc", X="h@pre", W="U", Y="Uh")
+        stepnet = core.Net.create()
         x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
         h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
         sum_op = Operator("add_two", X="Wx", Y="Uh", Out="sum")
@@ -158,6 +152,7 @@ def create_step_net(self):
         for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
             stepnet.add_op(op)
         stepnet.complete_add_op(True)
+        self.rnnop.set_stepnet(stepnet)
 
     def test_forward(self):
         print 'test recurrent op forward'