Skip to content

Commit

Permalink
Merge pull request #2629 from jacquesqiao/scope-impl
Browse files Browse the repository at this point in the history
basic scope implementation #2629

* scope-impl

* refine GetOrCreateVariable

* refine code of scope with style check

* refine code of operator

* clang format

* move code to scope.h, remove scope.cc

* add Create for scope

* merge CreateVar and GetOrCreateVar

* do not use default argument in Scope

* update design doc
  • Loading branch information
jacquesqiao authored Jun 29, 2017
2 parents aaee28b + 3c925fe commit 590795d
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 6 deletions.
12 changes: 6 additions & 6 deletions doc/design/scope.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Scope {
const Variable* GetVariable(const std::string& name) const;

private:
std::unordered_map<std::string, std::unique_ptr<Vairable>> vars_;
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};
```
Expand All @@ -59,9 +59,9 @@ class Scope {
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const {
Variable* var = GetVarLocally(name);
if (var != nullptr) {
return var;
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
Expand Down Expand Up @@ -97,8 +97,8 @@ class Scope {
// return nullptr if not found.
Variable* GetVariable(const std::string& name) const;

// return Error if already contains same name variable.
Error CreateVariable(const std::string& name);
// return if already contains same name variable.
Variable* CreateVariable(const std::string& name);

private:
std::shared_ptr<Scope> parent_;
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ddim lib
cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc)
95 changes: 95 additions & 0 deletions paddle/framework/scope.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/variable.h"

namespace paddle {
namespace framework {

/**
* @brief Scope that manage all variables.
*
* Scope is an association of a name to Variable. All variables belong to
* Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`.
* One net can run in different scopes and update different variable in the
* scope.
*/
class Scope {
public:
/**
* @brief Initialize s Scope without parent.
*/
Scope() {}

/**
* @brief Initialize a Scope with parent.
*/
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}

/**
* @brief Create Variable
*
* Create Variable in this Scope. Return the exist one if Variable already
* been created.
*/
Variable* CreateVariable(const std::string& name) {
auto var = GetVariable(name);
if (var) {
return var;
} else {
vars_[name] = std::unique_ptr<Variable>(new Variable());
return GetVariable(name);
}
}

/**
* @brief Get Variable.
*
* Get Variable from this Scope, this function will recursive find Variable
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}

/**
* @brief If this scope has a Var named name.
*
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() ||
(parent_ && parent_->HasVariable(name)));
}

private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::shared_ptr<Scope> parent_{nullptr};
};

} // namespace framework
} // namespace paddle
58 changes: 58 additions & 0 deletions paddle/framework/scope_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/scope.h"
#include "gtest/gtest.h"

TEST(Scope, Create) {
using paddle::framework::Scope;
using paddle::framework::Variable;

auto scope = std::make_shared<Scope>();

Variable* var0 = scope->CreateVariable("");
EXPECT_NE(var0, nullptr);

/// GetVariable will return nullptr if not exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);

/// CreateVariable will return one.
Variable* var2 = scope->CreateVariable("a");
EXPECT_NE(var2, nullptr);

/// Get the created variable.
Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3);

/// CreateVariable will just return the variable if it's
/// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
}

TEST(Scope, Parent) {
using paddle::framework::Scope;
using paddle::framework::Variable;

auto parent_scope = std::make_shared<Scope>();
auto scope = std::make_shared<Scope>(parent_scope);

Variable* var0 = parent_scope->CreateVariable("a");
EXPECT_NE(var0, nullptr);

/// GetVariable will get Variable from parent scope if exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var0, var1);
}

0 comments on commit 590795d

Please sign in to comment.