Skip to content

Commit

Permalink
更改相关文件 (Update related files)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuchen202 committed Sep 1, 2023
1 parent 0d08135 commit a5ba675
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 254 deletions.
26 changes: 13 additions & 13 deletions python/paddle/incubate/operators/unzip.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ def unzip(input, lod):
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
paddle.enable_static()
input_np = np.array([
[1.0, 2.0, 3.0, 4.0],
[10.0, 20.0, 30.0, 40.0],
[100.0, 200.0, 300.0, 400.0]
])
lod_np = np.array([0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12])
input = paddle.to_tensor(input_np, "int64")
lod = paddle.to_tensor(lod_np, "int64")
>>> import numpy as np
>>> import paddle
>>> import paddle.fluid as fluid
>>> paddle.enable_static()
>>> input_np = np.array([
... [1.0, 2.0, 3.0, 4.0],
... [10.0, 20.0, 30.0, 40.0],
... [100.0, 200.0, 300.0, 400.0]
... ])
>>> lod_np = np.array([0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12])
>>> input = paddle.to_tensor(input_np, "int64")
>>> lod = paddle.to_tensor(lod_np, "int64")
unzipped_input = paddle.incubate.unzip(input, lod)
>>> unzipped_input = paddle.incubate.unzip(input, lod)
'''
unzipped_input is [
[1.0, 2.0, 3.0, 4.0],
Expand Down
72 changes: 36 additions & 36 deletions python/paddle/incubate/optimizer/functional/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,46 +81,46 @@ def minimize_bfgs(
.. code-block:: python
:name: code-example1
# Example1: 1D Grid Parameters
import paddle
# Randomly simulate a batch of input data
inputs = paddle. normal(shape=(100, 1))
labels = inputs * 2.0
# define the loss function
def loss(w):
y = w * inputs
return paddle.nn.functional.square_error_cost(y, labels).mean()
# Initialize weight parameters
w = paddle.normal(shape=(1,))
# Call the bfgs method to solve the weight that makes the loss the smallest, and update the parameters
for epoch in range(0, 10):
# Call the bfgs method to optimize the loss, note that the third parameter returned represents the weight
w_update = paddle.incubate.optimizer.functional.minimize_bfgs(loss, w)[2]
# Use paddle.assign to update parameters in place
paddle. assign(w_update, w)
>>> # Example1: 1D Grid Parameters
>>> import paddle
>>> # Randomly simulate a batch of input data
>>> inputs = paddle.normal(shape=(100, 1))
>>> labels = inputs * 2.0
>>> # define the loss function
>>> def loss(w):
... y = w * inputs
... return paddle.nn.functional.square_error_cost(y, labels).mean()
>>> # Initialize weight parameters
>>> w = paddle.normal(shape=(1,))
>>> # Call the bfgs method to solve the weight that makes the loss the smallest, and update the parameters
>>> for epoch in range(0, 10):
... # Call the bfgs method to optimize the loss, note that the third parameter returned represents the weight
... w_update = paddle.incubate.optimizer.functional.minimize_bfgs(loss, w)[2]
... # Use paddle.assign to update parameters in place
... paddle.assign(w_update, w)
.. code-block:: python
:name: code-example2
# Example2: Multidimensional Grid Parameters
import paddle
def flatten(x):
return x. flatten()
def unflatten(x):
return x.reshape((2,2))
# Assume the network parameters are more than one dimension
def net(x):
assert len(x.shape) > 1
return x.square().mean()
# function to be optimized
def bfgs_f(flatten_x):
return net(unflatten(flatten_x))
x = paddle.rand([2,2])
for i in range(0, 10):
# Flatten x before using minimize_bfgs
x_update = paddle.incubate.optimizer.functional.minimize_bfgs(bfgs_f, flatten(x))[2]
# unflatten x_update, then update parameters
paddle. assign(unflatten(x_update), x)
>>> # Example2: Multidimensional Grid Parameters
>>> import paddle
>>> def flatten(x):
... return x.flatten()
>>> def unflatten(x):
... return x.reshape((2,2))
>>> # Assume the network parameters are more than one dimension
>>> def net(x):
... assert len(x.shape) > 1
... return x.square().mean()
>>> # function to be optimized
>>> def bfgs_f(flatten_x):
... return net(unflatten(flatten_x))
>>> x = paddle.rand([2,2])
>>> for i in range(0, 10):
... # Flatten x before using minimize_bfgs
... x_update = paddle.incubate.optimizer.functional.minimize_bfgs(bfgs_f, flatten(x))[2]
... # unflatten x_update, then update parameters
... paddle.assign(unflatten(x_update), x)
"""

if dtype not in ['float32', 'float64']:
Expand Down
72 changes: 36 additions & 36 deletions python/paddle/incubate/optimizer/functional/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,46 +82,46 @@ def minimize_lbfgs(
.. code-block:: python
:name: code-example1
# Example1: 1D Grid Parameters
import paddle
# Randomly simulate a batch of input data
inputs = paddle. normal(shape=(100, 1))
labels = inputs * 2.0
# define the loss function
def loss(w):
y = w * inputs
return paddle.nn.functional.square_error_cost(y, labels).mean()
# Initialize weight parameters
w = paddle.normal(shape=(1,))
# Call the bfgs method to solve the weight that makes the loss the smallest, and update the parameters
for epoch in range(0, 10):
# Call the bfgs method to optimize the loss, note that the third parameter returned represents the weight
w_update = paddle.incubate.optimizer.functional.minimize_bfgs(loss, w)[2]
# Use paddle.assign to update parameters in place
paddle. assign(w_update, w)
>>> # Example1: 1D Grid Parameters
>>> import paddle
>>> # Randomly simulate a batch of input data
>>> inputs = paddle.normal(shape=(100, 1))
>>> labels = inputs * 2.0
>>> # define the loss function
>>> def loss(w):
... y = w * inputs
... return paddle.nn.functional.square_error_cost(y, labels).mean()
>>> # Initialize weight parameters
>>> w = paddle.normal(shape=(1,))
>>> # Call the lbfgs method to solve the weight that makes the loss the smallest, and update the parameters
>>> for epoch in range(0, 10):
... # Call the lbfgs method to optimize the loss, note that the third parameter returned represents the weight
... w_update = paddle.incubate.optimizer.functional.minimize_lbfgs(loss, w)[2]
... # Use paddle.assign to update parameters in place
... paddle.assign(w_update, w)
.. code-block:: python
:name: code-example2
# Example2: Multidimensional Grid Parameters
import paddle
def flatten(x):
return x. flatten()
def unflatten(x):
return x.reshape((2,2))
# Assume the network parameters are more than one dimension
def net(x):
assert len(x.shape) > 1
return x.square().mean()
# function to be optimized
def bfgs_f(flatten_x):
return net(unflatten(flatten_x))
x = paddle.rand([2,2])
for i in range(0, 10):
# Flatten x before using minimize_bfgs
x_update = paddle.incubate.optimizer.functional.minimize_bfgs(bfgs_f, flatten(x))[2]
# unflatten x_update, then update parameters
paddle. assign(unflatten(x_update), x)
>>> # Example2: Multidimensional Grid Parameters
>>> import paddle
>>> def flatten(x):
... return x.flatten()
>>> def unflatten(x):
... return x.reshape((2,2))
>>> # Assume the network parameters are more than one dimension
>>> def net(x):
... assert len(x.shape) > 1
... return x.square().mean()
>>> # function to be optimized
>>> def lbfgs_f(flatten_x):
... return net(unflatten(flatten_x))
>>> x = paddle.rand([2,2])
>>> for i in range(0, 10):
... # Flatten x before using minimize_lbfgs
... x_update = paddle.incubate.optimizer.functional.minimize_lbfgs(lbfgs_f, flatten(x))[2]
... # unflatten x_update, then update parameters
... paddle.assign(unflatten(x_update), x)
"""
if dtype not in ['float32', 'float64']:
Expand Down
Loading

0 comments on commit a5ba675

Please sign in to comment.