From 11afd92a1809704de8915b8791761520290aac89 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 1 Dec 2017 13:52:11 +0800 Subject: [PATCH 1/2] Expose sigmoid_cross_entropy_with_logits Also, change the `labels` to `label` for api consistency --- .../sigmoid_cross_entropy_with_logits_op.cc | 24 +++++++-------- .../sigmoid_cross_entropy_with_logits_op.h | 6 ++-- python/paddle/v2/fluid/layers.py | 1 + python/paddle/v2/fluid/tests/test_layers.py | 10 +++++++ ...st_sigmoid_cross_entropy_with_logits_op.py | 29 +++++++++++-------- 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc index d9e40546523c6..782f4c79361b3 100644 --- a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -25,20 +25,19 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("Labels"), - "Input(Labels) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); auto x_dims = ctx->GetInputDim("X"); - auto labels_dims = ctx->GetInputDim("Labels"); + auto labels_dims = ctx->GetInputDim("Label"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(labels_dims.size(), 2, - "Input(Labels)'s rank should be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], - "The 1st dimension of Input(X) and Input(Labels) should " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], - "The 2nd dimension of Input(X) and Input(Labels) should " + "The 2nd dimension of Input(X) and Input(Label) should " "be equal."); ctx->SetOutputDim("Out", x_dims); @@ -53,26 +52,25 @@ class SigmoidCrossEntropyWithLogitsGradOp void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("Labels"), - "Input(Labels) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) shoudl be not null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Output(X@GRAD) should be not null."); auto x_dims = ctx->GetInputDim("X"); - auto labels_dims = ctx->GetInputDim("Labels"); + auto labels_dims = ctx->GetInputDim("Label"); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(labels_dims.size(), 2, - "Input(Labels)'s rank should be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(dout_dims.size(), 2, "Input(Out@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], - "The 1st dimension of Input(X) and Input(Labels) should " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], - "The 2nd dimension of Input(X) and Input(Labels) should " + "The 2nd dimension of Input(X) and Input(Label) should " "be equal."); PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0], "The 1st dimension of Input(X) and Input(Out@Grad) " @@ -97,7 +95,7 @@ class SigmoidCrossEntropyWithLogitsOpMaker "This input is a tensor of logits computed by the previous " " operator. Logits are unscaled log probabilities given as " "log(p/(1-p))."); - AddInput("Labels", + AddInput("Label", "(Tensor, default Tensor), a 2-D tensor of the same type " "and shape as X. This input is a tensor of probabalistic labels " "for each logit"); diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.h b/paddle/operators/sigmoid_cross_entropy_with_logits_op.h index 41c619f181c87..2a9d9bbc77266 100644 --- a/paddle/operators/sigmoid_cross_entropy_with_logits_op.h +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.h @@ -25,8 +25,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { const framework::Tensor *X = context.Input("X"); - const framework::Tensor *Labels = - context.Input("Labels"); + const framework::Tensor *Labels = context.Input("Label"); framework::Tensor *Out = context.Output("Out"); Out->mutable_data(context.GetPlace()); @@ -52,8 +51,7 @@ class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { const framework::Tensor *X = context.Input("X"); - const framework::Tensor *Labels = - context.Input("Labels"); + const framework::Tensor *Labels = context.Input("Label"); const framework::Tensor *dOut = context.Input(framework::GradVarName("Out")); framework::Tensor *dX = diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index 5a977978bf811..e41bfae285a5b 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -403,6 +403,7 @@ def func(**kwargs): _create_op_func_('scale') _create_op_func_('reshape') _create_op_func_('transpose') +_create_op_func_('sigmoid_cross_entropy_with_logits') def cast(x, dtype, main_program=None): diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 33b0e54f42afc..a9d9d369c7377 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -137,6 +137,16 @@ def test_linear_chain_crf(self): print(str(program)) + def test_sigmoid_cross_entropy(self): + program = Program() + with program_guard(program): + dat = layers.data(name='data', shape=[10], dtype='float32') + lbl = layers.data(name='label', shape=[10], dtype='float32') + self.assertIsNotNone( + layers.sigmoid_cross_entropy_with_logits( + x=dat, label=lbl)) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py index e53856b38aa5d..c42f578f72cb1 100644 --- a/python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py +++ b/python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py @@ -2,11 +2,12 @@ from op_test import OpTest from scipy.special import logit from scipy.special import expit +import unittest class TestSigmoidCrossEntropyWithLogitsOp1(OpTest): - '''Test sigmoid_cross_entropy_with_logit_op with binary labels - ''' + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" @@ -16,16 +17,16 @@ def setUp(self): 'X': logit( np.random.uniform(0, 1, (batch_size, num_classes)) .astype("float32")), - 'Labels': np.random.randint(0, 2, (batch_size, num_classes)) + 'Label': np.random.randint(0, 2, (batch_size, num_classes)) .astype("float32") } # Fw Pass is implemented as elementwise sigmoid followed by # elementwise logistic loss - # Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X)) + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Labels'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) self.outputs = {'Out': -term1 - term2} def test_check_output(self): @@ -36,8 +37,8 @@ def test_check_grad(self): class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): - '''Test sigmoid_cross_entropy_with_logit_op with probabalistic labels - ''' + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" @@ -47,16 +48,16 @@ def setUp(self): 'X': logit( np.random.uniform(0, 1, (batch_size, num_classes)) .astype("float32")), - 'Labels': np.random.uniform(0, 1, (batch_size, num_classes)) + 'Label': np.random.uniform(0, 1, (batch_size, num_classes)) .astype("float32") } # Fw Pass is implemented as elementwise sigmoid followed by # elementwise logistic loss - # Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X)) + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Labels'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) self.outputs = {'Out': -term1 - term2} def test_check_output(self): @@ -64,3 +65,7 @@ def test_check_output(self): def test_check_grad(self): self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() From 05db72ff15d1027c29a05fe851925b4481e8e9bb Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 1 Dec 2017 14:16:29 +0800 Subject: [PATCH 2/2] Very simple GAN based on pure FC layers --- python/paddle/v2/fluid/tests/demo/fc_gan.py | 157 ++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 python/paddle/v2/fluid/tests/demo/fc_gan.py diff --git a/python/paddle/v2/fluid/tests/demo/fc_gan.py b/python/paddle/v2/fluid/tests/demo/fc_gan.py new file mode 100644 index 0000000000000..cae959593e855 --- /dev/null +++ b/python/paddle/v2/fluid/tests/demo/fc_gan.py @@ -0,0 +1,157 @@ +import errno +import math +import os + +import matplotlib +import numpy + +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + +NOISE_SIZE = 100 +NUM_PASS = 1000 +NUM_REAL_IMGS_IN_BATCH = 121 +NUM_TRAIN_TIMES_OF_DG = 3 +LEARNING_RATE = 2e-5 + + +def D(x): + hidden = fluid.layers.fc(input=x, + size=200, + act='relu', + param_attr='D.w1', + bias_attr='D.b1') + logits = fluid.layers.fc(input=hidden, + size=1, + act=None, + param_attr='D.w2', + bias_attr='D.b2') + return logits + + +def G(x): + hidden = fluid.layers.fc(input=x, + size=200, + act='relu', + param_attr='G.w1', + bias_attr='G.b1') + img = fluid.layers.fc(input=hidden, + size=28 * 28, + act='tanh', + param_attr='G.w2', + bias_attr='G.b2') + return img + + +def plot(gen_data): + gen_data.resize(gen_data.shape[0], 28, 28) + n = int(math.ceil(math.sqrt(gen_data.shape[0]))) + fig = plt.figure(figsize=(n, n)) + gs = gridspec.GridSpec(n, n) + gs.update(wspace=0.05, hspace=0.05) + + for i, sample in enumerate(gen_data): + ax = plt.subplot(gs[i]) + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_aspect('equal') + plt.imshow(sample.reshape(28, 28), cmap='Greys_r') + + return fig + + +def main(): + try: + os.makedirs("./out") + except OSError as e: + if e.errno != errno.EEXIST: + raise + + startup_program = fluid.Program() + d_program = fluid.Program() + dg_program = fluid.Program() + + with fluid.program_guard(d_program, startup_program): + img = fluid.layers.data(name='img', shape=[784], dtype='float32') + d_loss = fluid.layers.sigmoid_cross_entropy_with_logits( + x=D(img), + label=fluid.layers.data( + name='label', shape=[1], dtype='float32')) + d_loss = fluid.layers.mean(x=d_loss) + + with fluid.program_guard(dg_program, startup_program): + noise = fluid.layers.data( + name='noise', shape=[NOISE_SIZE], dtype='float32') + g_img = G(x=noise) + g_program = dg_program.clone() + dg_loss = fluid.layers.sigmoid_cross_entropy_with_logits( + x=D(g_img), + label=fluid.layers.fill_constant_batch_size_like( + input=noise, dtype='float32', shape=[-1, 1], value=1.0)) + dg_loss = fluid.layers.mean(x=dg_loss) + + opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE) + + opt.minimize(loss=d_loss, startup_program=startup_program) + opt.minimize( + loss=dg_loss, + startup_program=startup_program, + parameter_list=[ + p.name for p in g_program.global_block().all_parameters() + ]) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + + num_true = NUM_REAL_IMGS_IN_BATCH + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=60000), + batch_size=num_true) + + for pass_id in range(NUM_PASS): + for batch_id, data in enumerate(train_reader()): + num_true = len(data) + n = numpy.random.uniform( + low=-1.0, high=1.0, + size=[num_true * NOISE_SIZE]).astype('float32').reshape( + [num_true, NOISE_SIZE]) + generated_img = exe.run(g_program, + feed={'noise': n}, + fetch_list={g_img})[0] + real_data = numpy.array(map(lambda x: x[0], data)).astype('float32') + real_data = real_data.reshape(num_true, 784) + total_data = numpy.concatenate([real_data, generated_img]) + total_label = numpy.concatenate([ + numpy.ones( + shape=[real_data.shape[0], 1], dtype='float32'), + numpy.zeros( + shape=[real_data.shape[0], 1], dtype='float32') + ]) + d_loss_np = exe.run(d_program, + feed={'img': total_data, + 'label': total_label}, + fetch_list={d_loss})[0] + for _ in xrange(NUM_TRAIN_TIMES_OF_DG): + n = numpy.random.uniform( + low=-1.0, high=1.0, + size=[2 * num_true * NOISE_SIZE]).astype('float32').reshape( + [2 * num_true, NOISE_SIZE, 1, 1]) + dg_loss_np = exe.run(dg_program, + feed={'noise': n}, + fetch_list={dg_loss})[0] + print("Pass ID={0}, Batch ID={1}, D-Loss={2}, DG-Loss={3}".format( + pass_id, batch_id, d_loss_np, dg_loss_np)) + # generate image each batch + fig = plot(generated_img) + plt.savefig( + 'out/{0}.png'.format(str(pass_id).zfill(3)), bbox_inches='tight') + plt.close(fig) + + +if __name__ == '__main__': + main()