
Commit

Update lora
co63oc committed Nov 29, 2023
1 parent a821aef commit 1830f87
Showing 29 changed files with 6,169 additions and 2,196 deletions.
597 changes: 423 additions & 174 deletions ppdiffusers/ppdiffusers/loaders.py

Large diffs are not rendered by default.

94 changes: 87 additions & 7 deletions ppdiffusers/ppdiffusers/models/activations.py
@@ -13,14 +13,94 @@
 # limitations under the License.
 
 import paddle.nn as nn
+import paddle.nn.functional as F
 
+from ..utils import USE_PEFT_BACKEND
+from .lora import LoRACompatibleLinear
 
-
-def get_activation(act_fn):
-    if act_fn in ["swish", "silu"]:
-        return nn.Silu()
-    elif act_fn == "mish":
-        return nn.Mish()
-    elif act_fn == "gelu":
-        return nn.GELU()
+ACTIVATION_FUNCTIONS = {
+    "swish": nn.Silu(),
+    "silu": nn.Silu(),
+    "mish": nn.Mish(),
+    "gelu": nn.GELU(),
+    "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Layer:
+    """Helper function to get activation function from string.
+
+    Args:
+        act_fn (str): Name of activation function.
+
+    Returns:
+        nn.Layer: Activation function.
+    """
+
+    act_fn = act_fn.lower()
+    if act_fn in ACTIVATION_FUNCTIONS:
+        return ACTIVATION_FUNCTIONS[act_fn]
     else:
         raise ValueError(f"Unsupported activation function: {act_fn}")
+
+
+class GELU(nn.Layer):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+        super().__init__()
+        self.proj = LoRACompatibleLinear(dim_in, dim_out)
+        self.approximate = approximate
+        self.approximate_bool = approximate == "tanh"
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = F.gelu(hidden_states, approximate=self.approximate_bool)
+        return hidden_states
+
+
+class GEGLU(nn.Layer):
+    r"""
+    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        self.proj = linear_cls(dim_in, dim_out * 2)
+
+    def forward(self, hidden_states, scale: float = 1.0):
+        args = () if USE_PEFT_BACKEND else (scale,)
+        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, axis=-1)
+        return hidden_states * F.gelu(gate)
+
+
+class ApproximateGELU(nn.Layer):
+    r"""
+    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+    [paper](https://arxiv.org/abs/1606.08415).
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out)
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x * F.sigmoid(1.702 * x)
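As a quick orientation for readers of this diff, here is a minimal usage sketch of the updated activations module. It is not part of the commit; it assumes a working ppdiffusers install with this change applied, and the tensor shapes and `scale` value are illustrative only.

import paddle

from ppdiffusers.models.activations import GEGLU, get_activation

# String lookup is now a case-insensitive dict hit instead of an if/elif chain.
act = get_activation("SiLU")   # returns the shared nn.Silu() instance
x = paddle.randn([2, 16, 320])
print(act(x).shape)            # [2, 16, 320]

# GEGLU projects to 2 * dim_out, splits along the last axis, and gates with GELU.
# The `scale` argument is a LoRA scale; it is only forwarded to the projection
# when USE_PEFT_BACKEND is False (i.e. when LoRACompatibleLinear is in use).
geglu = GEGLU(dim_in=320, dim_out=1280)
h = geglu(x, scale=1.0)
print(h.shape)                 # [2, 16, 1280]

Because ACTIVATION_FUNCTIONS holds pre-built, parameter-free layers, repeated calls to get_activation hand back shared instances rather than constructing a new layer each time.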
