merge with latest develop branch. Optimizer lib2 #2386

Merged
35 commits merged on Jun 20, 2017
Changes from 9 commits
Commits (35)
62cd5c7  "failed to resolve conflict. apply to HEAD" (dzhwinter, Jun 4, 2017)
3158efe  "move cmake scripts too" (dzhwinter, Jun 4, 2017)
5b8a0c5  "optimizer remove init create with proto" (dzhwinter, Jun 5, 2017)
8610ba1  "remove get config proto" (dzhwinter, Jun 5, 2017)
b4aa0ec  "modify update interface" (dzhwinter, Jun 5, 2017)
26e9c4e  "add vector alias to make name clear" (dzhwinter, Jun 5, 2017)
5f9cd8c  "rename test file name" (dzhwinter, Jun 5, 2017)
b9d024e  "remove useless test file" (dzhwinter, Jun 5, 2017)
5ab958b  "change size_t type to avoid warning" (dzhwinter, Jun 5, 2017)
fd8c510  "format name with google style" (dzhwinter, Jun 6, 2017)
3b1294a  "add checkpoint interface: set state, get state" (dzhwinter, Jun 6, 2017)
81cad37  "remove comments" (dzhwinter, Jun 6, 2017)
beb2697  "change header guard to pragma" (dzhwinter, Jun 6, 2017)
5a1e678  "update macro and fix some part" (dzhwinter, Jun 6, 2017)
bc26df7  "polish code style and update based review comment" (dzhwinter, Jun 7, 2017)
b9cb0f2  "update marco" (dzhwinter, Jun 7, 2017)
6cbbc2e  "add comments" (dzhwinter, Jun 7, 2017)
f5ff283  "fix comment" (dzhwinter, Jun 7, 2017)
e456796  "update with comment" (dzhwinter, Jun 9, 2017)
33b4dee  "update serialization part" (dzhwinter, Jun 9, 2017)
0fc4201  "update interface" (dzhwinter, Jun 9, 2017)
b7e68e0  "serialization modify" (dzhwinter, Jun 11, 2017)
b72e8aa  "seperate serialization proto state" (dzhwinter, Jun 13, 2017)
1814fc2  "fix lr_policy serialization" (dzhwinter, Jun 14, 2017)
e148bc1  "remove unused tensor line" (dzhwinter, Jun 14, 2017)
a46f3fc  "fix double release tensor buffer error." (dzhwinter, Jun 14, 2017)
df5bc78  "fix tensor shared_ptr" (dzhwinter, Jun 15, 2017)
65d9e33  "modify config name" (dzhwinter, Jun 19, 2017)
ec65fa8  "protobuf required to optional" (dzhwinter, Jun 19, 2017)
baef96e  Merge branch 'develop' into optimizer_lib2 (dzhwinter, Jun 19, 2017)
99849cf  rename Tensor.h (dzhwinter, Jun 19, 2017)
72b6b26  "ci formatter" (dzhwinter, Jun 19, 2017)
03884f0  formatter (dzhwinter, Jun 19, 2017)
a166e52  "formatter in docker" (dzhwinter, Jun 19, 2017)
33ddc89  formatter in docker (dzhwinter, Jun 19, 2017)
1 change: 1 addition & 0 deletions paddle/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(gserver)
add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(scripts)
add_subdirectory(optimizer)

# Do not build go directory until go cmake is working smoothly.
# if(CMAKE_Go_COMPILER)
28 changes: 28 additions & 0 deletions paddle/optimizer/CMakeLists.txt
@@ -0,0 +1,28 @@
include_directories(${CMAKE_CURRENT_BINARY_DIR})

set(OPITMIZER_SRCS
adadelta_optimizer.cc
adagrad_optimizer.cc
adam_optimizer.cc
optimizer.cc
parameter_optimizer.cc
sgd_optmizer.cc
)

set(OPITMIZER_Headers
adadelta_optimizer.h
adagrad_optimizer.h
adam_optimizer.h
lr_policy.h
optimizer.h
parameter_optimizer.h
sgd_optimizer.h
Tensor.h
)

add_library(optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(optimizer gen_proto_cpp)

add_simple_unittest(Tensor_test)
Contributor:
Let's use all lowercase for filenames (tensor_test.cpp instead of Tensor_test.cpp). Some filesystems are not case-sensitive.

Contributor Author:
fix done

add_simple_unittest(parameter_optimizer_test)
add_dependencies(parameter_optimizer_test optimizer)
49 changes: 49 additions & 0 deletions paddle/optimizer/Tensor.h
@@ -0,0 +1,49 @@
#ifndef PADDLE_OPTIMIZER_TENSOR_H_
jacquesqiao (Member), Jun 6, 2017:
PaddlePaddle uses #pragma once.

Contributor Author:
We will change that and follow the Google C++ style:
https://google.github.io/styleguide/cppguide.html#The__define_Guard

Member:
In a formal discussion we decided to use #pragma once (https://github.com/PaddlePaddle/cpp-primer-digest), so we can first change the digest and then change our code style.

#define PADDLE_OPTIMIZER_TENSOR_H_
/**
* @brief tensor used by optimizer
*/

#include <string.h>
#include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h"

namespace paddle {
namespace optimizer {

template <class T>
class TensorT {
public:
TensorT(size_t h, size_t w, T* data) : height_(h), width_(w), data_(data) {}
TensorT(T* data, int size) : height_(1), width_(size), data_(data) {}
TensorT(const TensorT& t)
Contributor:
Can you explain to me what the ":" does here? Sorry, I am not too familiar with it and don't know what keyword to search for.

Contributor Author:
This is a constructor initializer list in C++, which is the idiomatic way to initialize members in C++.

Please check here for details: http://en.cppreference.com/w/cpp/language/direct_initialization

Contributor Author:
The ":" after the parentheses is a C++ initialization mechanism, and it is not the same concept as the constructor body. The relationship between the initializer list and the constructor body is analogous to Python's __new__ and __init__: the initializer list completes before the constructor body (the part inside the braces) runs.
In English it is called the constructor initializer list.
1. The initializer list completes before any function body executes.
2. The order in which members are initialized is determined by their declaration order, not by the order in the list.
It also has restrictions for non-POD members:
https://stackoverflow.com/questions/5816218/difference-between-initializer-and-default-initializer-list-in-c
https://stackoverflow.com/questions/9903248/initializing-fields-in-constructor-initializer-list-vs-constructor-body

In general it is recommended to initialize all non-static members this way.
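
To make the point above concrete, here is a minimal standalone sketch of a constructor initializer list (the class and names are illustrative only, not code from this PR):

#include <cstddef>
#include <iostream>

// Members listed after the ':' are initialized by the constructor
// initializer list, before the constructor body runs.
class Buffer {
public:
  Buffer(size_t size, float* data) : size_(size), data_(data) {
    // By the time we get here, size_ and data_ already hold their values.
  }
  size_t size() const { return size_; }

private:
  const size_t size_;  // const members must be initialized in the list
  float* data_;
};

int main() {
  float storage[4] = {0};
  Buffer b(4, storage);
  std::cout << b.size() << std::endl;  // prints 4
  return 0;
}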

: TensorT(1, t.size(), 0, t.get_buffer(), false, false) {}
Contributor:
I guess this is the copy constructor, and what it does is create a new tensor by copying from the old tensor, so the two share the same buffer.
It seems that when the two tensors get destroyed, they will try to destroy the same buffer twice.

Contributor Author:
👍 fix done.
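
One common way to avoid this double free (and roughly the direction the later "fix tensor shared_ptr" commit takes) is to let copies share ownership of the buffer via std::shared_ptr. A minimal sketch, with illustrative names rather than the PR's final code:

#include <cstddef>
#include <memory>

// The buffer is owned by a shared_ptr, so copies of the tensor share
// ownership and the buffer is freed exactly once, when the last copy
// goes away.
template <class T>
class SharedTensor {
public:
  explicit SharedTensor(size_t size)
      : size_(size), data_(new T[size], std::default_delete<T[]>()) {}

  // The compiler-generated copy constructor copies the shared_ptr,
  // bumping the reference count instead of duplicating the raw pointer.
  T& operator[](size_t idx) { return data_.get()[idx]; }
  size_t size() const { return size_; }

private:
  size_t size_;
  std::shared_ptr<T> data_;
};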

TensorT& operator=(const TensorT& t) {
this->width_ = t.size();
this->data_ = t.get_buffer();
}
T* get_buffer() { return this->data_; }
T& operator[](const size_t idx) {
CHECK(idx >= 0 && idx < this->width_) << "out of index range";
return data_[idx];
}
T& operator[](const size_t idx) const {
CHECK(idx >= 0 && idx < this->width_) << "out of index range";
return data_[idx];
}
// TODO: replace with tensorshape
size_t size() const { return this->width_; }
Contributor:
Since height is already a member variable, I think you implemented a 2-d tensor, so this should be return this->width_ * this->height_;

Contributor Author:
It should be. fix done


protected:
size_t height_;
size_t width_;
T* data_;
};

// TODO(zhihong): design problem of dynamic datatype, need to fix it
Contributor:
It seems that when porting "majel" to PaddlePaddle, we already included boost/variant.hpp for the "single value, multiple type" container.

Contributor Author:
Agreed, 👍. Either we can wait for the majel port to finish, or implement another one with typeid reflection. It is a follow-up question.

Contributor:
I see.

typedef TensorT<real> Tensor;
Contributor:
I do not know whether real means float or double, and a Google search does not answer the question. Maybe for clarity we should change all real in this PR to float, since currently we are working with 32-bit floating-point parameters.

Contributor Author:
real is a macro type used across the whole project: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/utils/Common.h#L30
It is determined when the project is compiled. That is not a good choice here, since our element type comes from the caller's data, whose type is determined at runtime. This needs to be fixed.

dzhwinter (Contributor Author), Jun 6, 2017:
Let me explain this part more clearly.
In the last PR I implemented Tensor as a template class, but that is the wrong way, because a template class fixes the set of element types at compile time (explaining the reason in a meeting would be better).
In fact, a tensor should represent a chunk of contiguous memory and support dynamic types at runtime, e.g. https://github.com/caffe2/caffe2/blob/master/caffe2/core/tensor.h#L643, much like an enhanced any type in the boost library. We can merge this PR and enrich the tensor library in the future; at this moment we only support compile-time data types here.
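
For illustration only, a rough sketch of the "single value, multiple type" buffer idea mentioned in this thread, using std::variant where the comments mention boost::variant; the names are hypothetical, not this PR's API:

#include <cstddef>
#include <variant>
#include <vector>

// A "single value, multiple type" buffer: exactly one of the alternatives
// is held at runtime, and the held type can be inspected or visited.
using DynBuffer = std::variant<std::vector<float>, std::vector<double>>;

// Element count, whichever element type the buffer currently holds.
inline std::size_t BufferSize(const DynBuffer& buf) {
  return std::visit([](const auto& v) { return v.size(); }, buf);
}

int main() {
  DynBuffer b = std::vector<float>(16);  // holds float32 data at runtime
  std::size_t n = BufferSize(b);         // n == 16
  return n == 16 ? 0 : 1;
}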

helinwang (Contributor), Jun 6, 2017:
OK, that's fine. Can you use float instead of real? A parameter is supposed to be a general data store and should not depend on how real is defined in Paddle.

After this PR we need to work on another PR to add tensors of the required data types:

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
// only below is implemented, we need to implement other types in a follow up PR.
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

Contributor Author:
ok, fix done.


} // namespace optimizer
} // namespace paddle

#endif
21 changes: 21 additions & 0 deletions paddle/optimizer/Tensor_test.cpp
@@ -0,0 +1,21 @@
#include "Tensor.h"
#include <iostream>
jacquesqiao (Member), Jun 7, 2017:
Maybe we should add a copyright message to these new files?

Contributor Author:
The copyright declaration can be added when we merge these modules together. I noticed that the go module files all need that header section added as well.

#include "gtest/gtest.h"

using namespace paddle;
using namespace paddle::optimizer;

TEST(Tensor, indexer) {
real* ptr = new real[3];
Contributor:
Need to release the ptr.

Contributor Author:
fixed done.

Tensor t(ptr, 3);
for (auto i = 0; i < t.size(); ++i) {
t[i] = i;
}
ASSERT_EQ(t[2], 2);
ASSERT_EQ(t[1], 1);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
38 changes: 38 additions & 0 deletions paddle/optimizer/adadelta_optimizer.cc
@@ -0,0 +1,38 @@
#include "adadelta_optimizer.h"
#include <algorithm>
#include <cmath>

namespace paddle {
namespace optimizer {

void AdadeltaOptimizer::set_weight(Tensor* p) {
Contributor:
Why is the content of p not used in this function?

Contributor Author:
fix done

Contributor:
Need Google-style function names. Please replace all C++ function names with CamelCase.

Contributor Author:
https://google.github.io/styleguide/cppguide.html#Function_Names

"Accessors and mutators (get and set functions) may be named like variables. These often correspond to actual member variables, but this is not required. For example, int count() and void set_count(int count)."
Since this function modifies the parameter member of the optimizer, in my mind it falls under that accessor/mutator naming rule?

Contributor:
Thanks, sorry!

size_t size = p->size();
real* gptr = new real[size];
accum_gradient = new Tensor(gptr, size);
real* dptr = new real[size];
accum_delta = new Tensor(dptr, size);
real* dptr_current = new real[size];
update_delta = new Tensor(dptr_current, size);
}

void AdadeltaOptimizer::update(const Tensor* gradient) {
num_sample_passed += 1;
double learning_rate = lr_policy->get_learning_rate(num_sample_passed);
Tensor& param = *parameter_;
const Tensor& grad = *gradient;
Tensor& accum_g = *accum_gradient;
Tensor& accum_d = *accum_delta;
Tensor& update_d = *update_delta;
for (size_t i = 0; i < param.size(); ++i) {
accum_g[i] = rho * accum_g[i] + (1.0 - rho) * grad[i] * grad[i];

update_d[i] = std::sqrt(accum_d[i] + epsilon) /
std::sqrt(accum_g[i] + epsilon) * grad[i];

accum_d[i] = rho * accum_d[i] + (1.0 - rho) * update_d[i] * update_d[i];

param[i] -= learning_rate * update_d[i] + learning_rate * decay * param[i];
}
}
} // namespace optimizer
} // namespace paddle
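
For reference, the loop above implements the standard AdaDelta recurrences plus the weight-decay term used in this code, where accum_g is $E[g^2]$, accum_d is $E[\Delta\theta^2]$, update_d is $\Delta\theta$, $\eta$ is the learning rate, and $\lambda$ is decay:

$$
\begin{aligned}
E[g^2]_t &= \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2 \\
\Delta\theta_t &= \frac{\sqrt{E[\Delta\theta^2]_{t-1}+\epsilon}}{\sqrt{E[g^2]_t+\epsilon}}\,g_t \\
E[\Delta\theta^2]_t &= \rho\,E[\Delta\theta^2]_{t-1} + (1-\rho)\,\Delta\theta_t^2 \\
\theta_{t+1} &= \theta_t - \eta\,\Delta\theta_t - \eta\,\lambda\,\theta_t
\end{aligned}
$$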
39 changes: 39 additions & 0 deletions paddle/optimizer/adadelta_optimizer.h
@@ -0,0 +1,39 @@
#ifndef PADDLE_ADADELTA_OPTIMIZER_H_
#define PADDLE_ADADELTA_OPTIMIZER_H_

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

class AdadeltaOptimizer : public ParameterOptimizer {
public:
using ParameterOptimizer::parameter_;
Contributor:
Probably parameter_, num_sample_passed, and lr_policy should be private?

Contributor Author:
These members will be accessed by the derived classes, and the private keyword would forbid that.
Actually they should be protected, but I found that other Paddle classes prefer to define base-class members as public, so I just kept them consistent with the rest of the project.

Contributor:
We should not use public member variables if they are not intended to be public; it breaks encapsulation. Protected is fine with me. I understand that being consistent is important, but we should not stay consistent with poor design, otherwise how can we improve code quality?

using ParameterOptimizer::num_sample_passed;
using ParameterOptimizer::lr_policy;

AdadeltaOptimizer(double rho, double epsilon, double decay, BaseLr *lr)
: ParameterOptimizer(lr), rho(rho), epsilon(epsilon), decay(decay) {}
~AdadeltaOptimizer() {
if (accum_gradient) delete accum_gradient;
if (accum_delta) delete accum_delta;
if (update_delta) delete update_delta;
}
void update(const Tensor *gradient);
void set_weight(Tensor *p);
real *get_weight() const;

private:
Tensor *accum_gradient;
Tensor *accum_delta;
Tensor *update_delta;

double rho;
double epsilon;
double decay;
};

} // namespace optimizer
} // namespace paddle

#endif
28 changes: 28 additions & 0 deletions paddle/optimizer/adagrad_optimizer.cc
@@ -0,0 +1,28 @@
#include <cmath>

#include "adagrad_optimizer.h"

namespace paddle {
namespace optimizer {

void AdagradOptimizer::set_weight(Tensor* p) {
size_t size = p->size();
Contributor:
Seems the content of p is not used.

Contributor Author:
That's a horrible mistake....
fix done.

real* gptr = new real[size];
accum_gradient = new Tensor(gptr, size);
}

void AdagradOptimizer::update(const Tensor* gradient) {
Contributor:
We need to follow the Google C++ code style for C++ function names: https://google.github.io/styleguide/cppguide.html#Function_Names

Contributor Author:
fix done.

num_sample_passed += 1;
double learning_rate = lr_policy->get_learning_rate(num_sample_passed);
Tensor& param = *parameter_;
const Tensor& grad = *gradient;
Tensor& accum_g = *accum_gradient;
for (size_t i = 0; i < param.size(); ++i) {
accum_g[i] += grad[i] * grad[i];
param[i] += learning_rate * grad[i] / std::sqrt(accum_g[i] + epsilon) +
learning_rate * decay * param[i];
}
}

} // namespace optimizer
} // namespace paddle
29 changes: 29 additions & 0 deletions paddle/optimizer/adagrad_optimizer.h
@@ -0,0 +1,29 @@
#ifndef PADDLE_ADAGRAD_OPTIMIZER_H_
#define PADDLE_ADAGRAD_OPTIMIZER_H_

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

class AdagradOptimizer : public ParameterOptimizer {
public:
AdagradOptimizer(double epsilon, double decay, BaseLr *lr)
: ParameterOptimizer(lr), epsilon(epsilon), decay(decay) {}
~AdagradOptimizer() {
if (accum_gradient) delete accum_gradient;
}
void update(const Tensor *gradient);
void set_weight(Tensor *p);
real *get_weight() const;

private:
Tensor *accum_gradient;
double epsilon;
double decay;
};

} // namespace optimizer
} // namespace paddle

#endif
33 changes: 33 additions & 0 deletions paddle/optimizer/adam_optimizer.cc
@@ -0,0 +1,33 @@
#include "adam_optimizer.h"
#include <cmath>

namespace paddle {
namespace optimizer {

void AdamOptimizer::set_weight(Tensor *p) {
size_t size = p->size();
real *mptr = new real[size];
momentums_ = new Tensor(mptr, size);
real *vptr = new real[size];
velocitys_ = new Tensor(vptr, size);
}

void AdamOptimizer::update(const Tensor *gradient) {
num_sample_passed += 1;
double learning_rate = lr_policy->get_learning_rate(num_sample_passed);
double coef1 = 1.0 - std::pow(beta_1, num_sample_passed);
double coef2 = 1.0 - std::pow(beta_2, num_sample_passed);
learning_rate *= std::sqrt(coef2) / coef1;
Tensor &param = *parameter_;
const Tensor &grad = *gradient;
Tensor &m = *momentums_;
Tensor &v = *velocitys_;
for (size_t i = 0; i < param.size(); ++i) {
m[i] = beta_1 * m[i] + (1.0 - beta_1) * grad[i];
v[i] = beta_2 * v[i] + (1.0 - beta_2) * grad[i] * grad[i];
param[i] -=
learning_rate * (m[i] / std::sqrt(v[i] + epsilon) + decay * param[i]);
}
}
} // namespace optimizer
} // namespace paddle
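
For reference, the loop above is the usual Adam update with the bias correction folded into the learning rate, plus the weight-decay term in this code, where t is num_sample_passed and $\lambda$ is decay:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2 \\
\hat\eta_t &= \eta\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} \\
\theta_{t+1} &= \theta_t - \hat\eta_t\left(\frac{m_t}{\sqrt{v_t+\epsilon}} + \lambda\,\theta_t\right)
\end{aligned}
$$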
37 changes: 37 additions & 0 deletions paddle/optimizer/adam_optimizer.h
@@ -0,0 +1,37 @@
#ifndef PADDLE_ADAM_OPTIMIZER_H_
#define PADDLE_ADAM_OPTIMIZER_H_

#include "parameter_optimizer.h"

namespace paddle {
namespace optimizer {

class AdamOptimizer : public ParameterOptimizer {
public:
AdamOptimizer(
double beta_1, double beta_2, double epsilon, double decay, BaseLr *lr)
: ParameterOptimizer(lr),
beta_1(beta_1),
beta_2(beta_2),
epsilon(epsilon),
decay(decay) {}
~AdamOptimizer() {
if (momentums_) delete momentums_;
if (velocitys_) delete velocitys_;
}
void update(const Tensor *gradient);
void set_weight(Tensor *p);
real *get_weight() const;

private:
Tensor *momentums_;
Tensor *velocitys_;
double beta_1;
double beta_2;
double epsilon;
double decay;
};

} // namespace optimizer
} // namespace paddle
#endif
45 changes: 45 additions & 0 deletions paddle/optimizer/lr_policy.h
@@ -0,0 +1,45 @@
#ifndef PADDLE_OPTIMIZER_LR_POLICY_H_
#define PADDLE_OPTIMIZER_LR_POLICY_H_

#include <algorithm>
#include "OptimizerConfig.pb.h"

namespace paddle {
namespace optimizer {

class BaseLr {
public:
BaseLr(double lr) : learning_rate(lr) {}
virtual ~BaseLr() {}
virtual double get_learning_rate(const uint64_t num_sample_passed) = 0;
Contributor:
get_learning_rate() is already pure virtual with no implementation here, so also having the learning_rate member variable in this class feels a bit odd: learning_rate should live in the same class that implements get_learning_rate().
I think writing this directly as an interface may be more appropriate: https://stackoverflow.com/a/318137/852385

Also, since this file is already called lr_policy, wouldn't it be clearer to name the class LrPolicy?

Finally, the Google C++ coding style requires CamelCase for function names: https://google.github.io/styleguide/cppguide.html#Function_Names

Wouldn't something like this be clearer:

class LrPolicy {
public:
  virtual ~LrPolicy() {}
  virtual double LearningRate(const uint64_t num_sample_passed) = 0;
};

Contributor Author:
I thought of it as a base class, with the derived classes named by the convention XXXLr. fix done.


protected:
double learning_rate;
};

// constant learning rate policy
class ConstLr final : public BaseLr {
public:
ConstLr(double lr) : BaseLr(lr){};
double get_learning_rate(const uint64_t num_sample_passed) {
return learning_rate;
}
};

class LinearLr final : public BaseLr {
public:
LinearLr(double lr, double lr_decay_a, double lr_decay_b)
: BaseLr(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
double get_learning_rate(const uint64_t num_sample_passed) {
return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
}

private:
double lr_decay_a;
double lr_decay_b;
};

} // namespace optimizer
} // namespace paddle

#endif