-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2629 from jacquesqiao/scope-impl
basic scope implementation #2629 * scope-impl * refine GetOrCreateVariable * refine code of scope with style check * refine code of operator * clang format * move code to scope.h, remove scope.cc * add Create for scope * merge CreateVar and GetOrCreateVar * do not use default argument in Scope * update design doc
- Loading branch information
Showing
4 changed files
with
161 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
# ddim lib | ||
cc_library(ddim SRCS ddim.cc) | ||
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) | ||
nv_test(dim_test SRCS dim_test.cu DEPS ddim) | ||
cc_test(variable_test SRCS variable_test.cc) | ||
cc_test(scope_test SRCS scope_test.cc) | ||
cc_test(enforce_test SRCS enforce_test.cc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include "paddle/framework/variable.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
/** | ||
* @brief Scope that manage all variables. | ||
* | ||
* Scope is an association of a name to Variable. All variables belong to | ||
* Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. | ||
* One net can run in different scopes and update different variable in the | ||
* scope. | ||
*/ | ||
class Scope { | ||
public: | ||
/** | ||
* @brief Initialize s Scope without parent. | ||
*/ | ||
Scope() {} | ||
|
||
/** | ||
* @brief Initialize a Scope with parent. | ||
*/ | ||
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {} | ||
|
||
/** | ||
* @brief Create Variable | ||
* | ||
* Create Variable in this Scope. Return the exist one if Variable already | ||
* been created. | ||
*/ | ||
Variable* CreateVariable(const std::string& name) { | ||
auto var = GetVariable(name); | ||
if (var) { | ||
return var; | ||
} else { | ||
vars_[name] = std::unique_ptr<Variable>(new Variable()); | ||
return GetVariable(name); | ||
} | ||
} | ||
|
||
/** | ||
* @brief Get Variable. | ||
* | ||
* Get Variable from this Scope, this function will recursive find Variable | ||
* from it's parent scope. Return nullptr if not found. | ||
*/ | ||
Variable* GetVariable(const std::string& name) const { | ||
auto it = vars_.find(name); | ||
if (it != vars_.end()) { | ||
return it->second.get(); | ||
} else if (parent_ != nullptr) { | ||
return parent_->GetVariable(name); | ||
} else { | ||
return nullptr; | ||
} | ||
} | ||
|
||
/** | ||
* @brief If this scope has a Var named name. | ||
* | ||
* Find if there is a Variable in this scope and it's parent scope | ||
*/ | ||
bool HasVariable(const std::string& name) const { | ||
return (vars_.find(name) != vars_.end() || | ||
(parent_ && parent_->HasVariable(name))); | ||
} | ||
|
||
private: | ||
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; | ||
std::shared_ptr<Scope> parent_{nullptr}; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/scope.h" | ||
#include "gtest/gtest.h" | ||
|
||
TEST(Scope, Create) { | ||
using paddle::framework::Scope; | ||
using paddle::framework::Variable; | ||
|
||
auto scope = std::make_shared<Scope>(); | ||
|
||
Variable* var0 = scope->CreateVariable(""); | ||
EXPECT_NE(var0, nullptr); | ||
|
||
/// GetVariable will return nullptr if not exist. | ||
Variable* var1 = scope->GetVariable("a"); | ||
EXPECT_EQ(var1, nullptr); | ||
|
||
/// CreateVariable will return one. | ||
Variable* var2 = scope->CreateVariable("a"); | ||
EXPECT_NE(var2, nullptr); | ||
|
||
/// Get the created variable. | ||
Variable* var3 = scope->GetVariable("a"); | ||
EXPECT_EQ(var2, var3); | ||
|
||
/// CreateVariable will just return the variable if it's | ||
/// already exist. | ||
Variable* var4 = scope->CreateVariable("a"); | ||
EXPECT_EQ(var4, var2); | ||
} | ||
|
||
TEST(Scope, Parent) { | ||
using paddle::framework::Scope; | ||
using paddle::framework::Variable; | ||
|
||
auto parent_scope = std::make_shared<Scope>(); | ||
auto scope = std::make_shared<Scope>(parent_scope); | ||
|
||
Variable* var0 = parent_scope->CreateVariable("a"); | ||
EXPECT_NE(var0, nullptr); | ||
|
||
/// GetVariable will get Variable from parent scope if exist. | ||
Variable* var1 = scope->GetVariable("a"); | ||
EXPECT_EQ(var0, var1); | ||
} |