[CoqTorch] Add preliminary activation function support
JasonGross committed Oct 23, 2023
1 parent deeddd8 commit 4fe4f07
Showing 9 changed files with 190 additions and 24 deletions.
2 changes: 1 addition & 1 deletion theories/MaxOfTwoNumbers/Parameters.v
@@ -16,7 +16,7 @@ Module cfg <: CommonConfig.
Definition model_name := "custom".
Definition n_heads := 1%N.
Definition d_mlp := @None Z.
Definition act_fn := "relu".
Definition act_fn := @None ActivationKind.
Definition d_vocab := 64%N.
Definition eps := 1e-05.
Definition use_attn_result := false.
4 changes: 2 additions & 2 deletions theories/MaxOfTwoNumbersSimpler/Parameters.v
@@ -14,7 +14,7 @@ Module cfg <: CommonConfig.
Definition model_name := "custom".
Definition n_heads := 1%N.
Definition d_mlp := @None Z.
Definition act_fn := @None string.
Definition act_fn := @None ActivationKind.
Definition d_vocab := 64%N.
Definition eps := 0x1.4f8b588e368f1p-17.
Definition use_attn_result := false.
@@ -454,4 +454,4 @@ Definition L0_attn_b_O :=
.
Definition L0_attn_b_V :=
[[0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0]]
.
.
@@ -373,7 +373,7 @@ Proof.
specialize (H' i' (Z.to_nat (Uint63.to_Z i')) pf));
[ cbv [Reduction.in_bounds_alt_at]; clear;
rewrite ?nat_N_Z, ?Z2Nat.id, ?of_to_Z by lia;
cbv [Classes.add Classes.mul Classes.zero int_has_add Classes.one int_has_one int_has_mul int_has_zero];
cbv [Classes.add Classes.mul Classes.zero Classes.max int_has_add Classes.one Classes.eqb Uint63.max int_has_eqb int_has_one int_has_mul int_has_zero has_default_max_leb Classes.leb int_has_leb Uint63.leb] in *;
match goal with
| [ |- context[?v] ]
=> lazymatch v with context[i'] => fail | context[if _ then _ else _] => idtac end;
@@ -382,6 +382,7 @@
| nat => idtac
| int => idtac
| N => idtac
| bool => idtac
end;
let v' := (eval vm_compute in v) in
progress change v with v'
@@ -403,9 +404,9 @@
cbv [inject_int] in *.
specialize_step i'.
specialize_step i'.
cbv [Classes.modulo int_has_modulo] in *.
cbv [Classes.modulo int_has_modulo Classes.max Uint63.max has_default_max_leb Classes.leb int_has_leb Uint63.leb] in *.
set (i'' := (i' mod _)%uint63) in *.
assert (i' = i'') by (clear; subst i' i''; nia).
assert (i' = i'') by (clear; subst i' i''; try nia).
clearbody i''; subst i''.
move indices_of_max at bottom.
subst min_incorrect_logit.
2 changes: 1 addition & 1 deletion theories/MaxOfTwoNumbersUndertrainedSimpler/Parameters.v
@@ -14,7 +14,7 @@ Module cfg <: CommonConfig.
Definition model_name := "custom".
Definition n_heads := 1%N.
Definition d_mlp := @None Z.
Definition act_fn := @None string.
Definition act_fn := @None ActivationKind.
Definition d_vocab := 64%N.
Definition eps := 0x1.4f8b588e368f1p-17.
Definition use_attn_result := false.
5 changes: 4 additions & 1 deletion theories/Torch/Tensor.v
@@ -6,8 +6,8 @@ From NeuralNetInterp.Util Require Import Wf_Uint63 PArray.Proofs List.Proofs Def
Import Util.Nat.Notations.
Import Util.Wf_Uint63.LoopNotation.
Import Util.Wf_Uint63.Reduction.
Import Arith.Classes.
Import Instances.Uint63.
Import Arith.Classes.
Local Open Scope list_scope.
Set Implicit Arguments.
Import ListNotations.
@@ -1073,6 +1073,9 @@ Definition reshape {A r1 r2} {s1 : Shape r1} (t : tensor s1 A) (s2 : Shape r2) :
:= unreshape_m1 (reshape_m1 t : tensor (Shape.reshape s2) A).
*)

Definition relu {r} {s : Shape r} {A} {zeroA : has_zero A} {maxA : has_max A} (xs : tensor s A) : tensor s A
:= map (max 0) xs.
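(* Illustrative usage sketch, not part of this commit: [relu] slots into the
   same typeclass context as the other pointwise operations here, so any [A]
   with [has_zero] and [has_max] instances works.  [relu_usage_sketch] is a
   hypothetical name introduced only for this sketch. *)
Example relu_usage_sketch {r} {s : Shape r} {A} {zeroA : has_zero A} {maxA : has_max A}
  (xs : tensor s A) : tensor s A
  := relu (relu xs).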

Section reduce_axis_m1.
Context {r} {s1 : Shape r} {s2 : ShapeType} {keepdim : with_default "keepdim" bool false}
{A}
17 changes: 17 additions & 0 deletions theories/Torch/Tensor/Instances.v
@@ -427,6 +427,23 @@ Qed.
#[export] Instance unreshape_all_Proper {r s A R} : Proper (eqfR R ==> eqfR R) (@unreshape_all r s A).
Proof. apply unreshape_all_Proper_dep. Qed.

#[export] Instance relu_Proper_dep {r s}
: Dependent.Proper
(Dependent.idR
==> (Dependent.idR ==> Dependent.idR ==> Dependent.idR)
==> eqfR
==> eqfR)
(@relu r s).
Proof.
repeat intro; cbv [relu].
cbv -[RawIndex tensor Shape map] in *.
eapply map_Proper_dep; repeat intro; hnf in *; eauto.
Qed.

#[export] Instance relu_Proper {r s A zeroA maxA}
: Proper (eqf ==> eqf) (@relu r s A zeroA maxA).
Proof. apply relu_Proper_dep; repeat intro; subst; reflexivity. Qed.
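(* Sketch, not part of this commit: the instance above specializes to the
   expected pointwise statement; [relu_respects_eqf] is a hypothetical name
   used only for illustration. *)
Example relu_respects_eqf {r s A zeroA maxA} x y
  (H : eqf x y)
  : eqf (@relu r s A zeroA maxA x) (@relu r s A zeroA maxA y)
  := relu_Proper x y H.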

#[export] Instance sum_dim_m1_Proper_dep {r s1 s2 keepdim}
: Dependent.Proper
(Dependent.idR
136 changes: 136 additions & 0 deletions theories/TransformerLens/HookedTransformer.v
@@ -1,3 +1,90 @@
(** This file corresponds roughly to components.py from
transformer_lens at
https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/components.py
In this file, each class in components.py is turned into a Coq
Module purely for namespacing / code organizational purposes.
- In this file, NN configuration parameters are taken as arguments
to each function.
- In [HookedTransformer/Module.v], we organize the same building
blocks into module functors where the configuration parameters
are taken in as module functor arguments, using the code defined
in this file.
  - This file makes it possible to prove theorems about the functions
    that are universally quantified over all parameters, while the
    module functor organization makes it possible to avoid the overhead
    of passing around the parameters to every function call (hopefully;
    I'm still figuring out ideal design choices here).
- It might be better to combine the files and stick with just the
module functor based organization, since that's the one we
ultimately use.
*)
(** Design principles:
Two goals for code in this file:
1. Be able to run the code (efficiently, on primitive floats and
arrays)
2. Be able to reuse the same code to prove theorems about the
computation (using, e.g., reals or rationals instead of
primitive floats, so we get more nice mathematical properties).
As a result, we want to:
- parameterize each function over the type of data (and the
associated operations on that datatype that we need)
- order the arguments so that we can define [Proper] instances
relating instantiations on different datatypes; arguments that
are the same across instantiations (such as vocab size) come
first, while arguments that vary (such as how to do addition on
the datatype) come later.
Additionally, we parameterize functions over the batch size and
the shape of the tensor.
Ordering of the particular datatype operations is currently a bit
ad-hoc and disorganized, though some effort is made to maintain
consistency across components.
In each component, arguments are specified with a [Context]
directive in an unnamed [Section] to allow sharing of argument
declarations between different functions within that component.
Coq automatically determines the subset of arguments used for each
function.
- Most arguments are implicit and maximally inserted (using curly
braces [{}]) so that they are picked up by type inference or
typeclass resolution. Exceptions are the weights and biases and
the tensors that are passed into the python code.
- We pass a [use_checkpoint : with_default "use_checkpoint" bool
true] argument which specifies whether or not we materialize
concrete arrays at various points in the code.
- Because our tensors are represented as functions from indices
to data, by default computation is lazy and deferred until
concrete indices are given.
- This is useful to avoid duplicating computation when
broadcasting scalars, but would involve massive duplication in
other cases such as recomputing the same mean repeatedly for
every index separately.
- Hence we use [PArray.maybe_checkpoint] to materialize
computations before any step that might duplicate them.
- Materializing arrays gets in the way of proofs, so we use this
parameter to ease infrastructure for removing all array
materialization simultaneously in a single proof step. *)
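(* Illustrative sketch, not part of this commit, of the parameterization
   pattern described above: structural arguments and the carrier type come
   first, the operations on the carrier come later, and a [Proper]-style
   lemma relates instantiations on different carriers.  All names inside
   [DesignSketch] are hypothetical. *)
From Coq Require Import Morphisms.
Module DesignSketch.
  Section __.
    Context {A : Type} (addA : A -> A -> A) (R : A -> A -> Prop)
            (addA_Proper : Proper (R ==> R ==> R) addA).
    Definition double (x : A) : A := addA x x.
    Lemma double_Proper : Proper (R ==> R) double.
    Proof. intros x y Hxy; cbv [double]; apply addA_Proper; exact Hxy. Qed.
  End __.
End DesignSketch.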

From Coq Require Import Floats Sint63 Uint63 QArith Lia List PArray Morphisms RelationClasses.
From NeuralNetInterp.Util Require Import Default Pointed PArray List Notations Arith.Classes Arith.Instances Bool PrimitiveProd.
From NeuralNetInterp.Util Require Nat Wf_Uint63.
@@ -13,6 +100,7 @@ Local Open Scope list_scope.
Set Implicit Arguments.
Import ListNotations.
Local Open Scope raw_tensor_scope.
Local Open Scope core_scope.

Notation tensor_of_list ls := (Tensor.PArray.abstract (Tensor.PArray.concretize (Tensor.of_list ls))) (only parsing).
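(* The notation above concretizes the list into a primitive array and then
   re-abstracts it, so tensor literals are materialized once up front rather
   than being recomputed lazily at every index. *)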

@@ -219,6 +307,54 @@ Module Attention.
End __.
End Attention.

Module MLP.
Section __.
Context {r} {batch : Shape r}
{pos d_model d_mlp}
(act_fn_kind : ActivationKind)
{A}
{addA : has_add A} {mulA : has_mul A}
{maxA : has_max A}
{zeroA : has_zero A}
{use_checkpoint : with_default "use_checkpoint" bool true}
(W_in : tensor [d_model; d_mlp] A) (b_in : tensor [d_mlp] A)
(W_out : tensor [d_mlp; d_model] A) (b_out : tensor [d_model] A)
(x : tensor (batch ::' pos ::' d_model) A)
.

Definition pre_act : tensor (batch ::' pos ::' d_mlp) A
:= let x' : tensor (batch ::' pos ::' d_mlp) A
:= Tensor.map'
(fun x : tensor [pos; d_mlp] A
=> weaksauce_einsum {{{ {{ pos d_model , d_model d_mlp -> pos d_mlp }}
, x
, W_in }}}
: tensor [pos; d_mlp] A)
x in
x' + broadcast b_in.

Definition act_fn : tensor (batch ::' pos ::' d_mlp) A -> tensor (batch ::' pos ::' d_mlp) A
:= match act_fn_kind with
| relu => Tensor.relu
end.
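    (* [act_fn] dispatches on [ActivationKind]; as further constructors (e.g.
       the commented-out [gelu] variants in Config/Common.v) are enabled, this
       match is where they would be mapped to the corresponding [Tensor]
       functions. *)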

(* TODO: if act_fn is *_ln, then handle ln *)
Definition post_act : tensor (batch ::' pos ::' d_mlp) A
:= act_fn pre_act.
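    (* In conventional notation, [forward] below computes
       MLP(x) = act_fn(x · W_in + b_in) · W_out + b_out;
       [pre_act] is the argument of [act_fn] and [post_act] its result. *)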

Definition forward : tensor (batch ::' pos ::' d_model) A
:= let fx' : tensor (batch ::' pos ::' d_model) A
:= Tensor.map'
(fun fx : tensor [pos; d_mlp] A
=> weaksauce_einsum {{{ {{ pos d_mlp , d_mlp d_model -> pos d_model }}
, fx
, W_out }}}
: tensor [pos; d_mlp] A)
post_act in
fx' + broadcast b_out.
End __.
End MLP.

Module TransformerBlock.
Section __.
Context {r} {batch : Shape r}
24 changes: 14 additions & 10 deletions theories/TransformerLens/HookedTransformer/Config/Common.v
@@ -1,6 +1,6 @@
(** Ported from https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/HookedTransformerConfig.py *)
From Coq Require Import Floats Uint63 ZArith NArith.
From NeuralNetInterp.Util Require Import Default.
From NeuralNetInterp.Util Require Import Option Default.
(** Copying the docstring from Python:
<<<
@@ -124,6 +132,14 @@ From NeuralNetInterp.Util Require Import Default.
>>> *)

Variant NormalizationType := LN (* | LNPre *) .
Variant ActivationKind :=
| relu
(* | gelu *)
(* | silu *)
(* | gelu_new *)
(* | solu_ln *)
(* | gelu_fast *)
.
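(* Illustrative sketch, not part of this commit: a model with an MLP would
   populate the new [act_fn] config field as [Some relu], whereas the
   attention-only models updated in this commit use [@None ActivationKind].
   [example_act_fn_with_mlp] is a hypothetical name used only here. *)
Definition example_act_fn_with_mlp : option ActivationKind := Some relu.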

Module Type CommonConfig.
Parameter d_model : N.
@@ -132,8 +140,12 @@ Module Type CommonConfig.
Parameter d_vocab : N.
Parameter d_vocab_out : N.
Parameter n_heads : N.
Parameter eps : float.
#[local] Set Warnings "-inexact-float".
Parameter eps : with_default "eps" float (1e-5)%float.
#[local] Set Warnings "inexact-float".
Parameter normalization_type : with_default "normalization_type" (option NormalizationType) (Some LN).
Parameter act_fn : with_default "act_fn" (option ActivationKind) None.
Definition attn_only : with_default "attn_only" bool false := Option.is_None act_fn.
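  (* [attn_only] is derived from [act_fn]: a model exported without an
     activation function is treated as attention-only. *)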
(*Parameter use_split_qkv_input : with_default "use_split_qkv_input" bool false.*)
(*Notation maybe_n_heads := (if use_split_qkv_input as b return Shape (if b then _ else _) then [n_heads] else [])%shape (only parsing).*)
End CommonConfig.
@@ -149,14 +161,6 @@ Import RecordSetNotations.
#[local] Set Decidable Equality Schemes.
#[local] Set Boolean Equality Schemes.
Variant ActivationKind :=
| relu
(* | gelu *)
(* | silu *)
(* | gelu_new *)
(* | solu_ln *)
(* | gelu_fast *)
.
(** Copying the docstring from Python:
17 changes: 11 additions & 6 deletions training/coq_export_utils.py
@@ -7,7 +7,7 @@


def strify(v, ty=None, description=None, parens_if_space=False):
tymap = {'int': 'Z', 'float':'Q', 'str':'string', 'bool':'bool', 'NormalizationType':'NormalizationType'}
tymap = {'int': 'Z', 'float':'Q', 'str':'string', 'bool':'bool', 'NormalizationType':'NormalizationType', 'ActivationKind':'ActivationKind'}
def wrap_parens(s):
return f'({s})' if parens_if_space else s
if v is None:
@@ -23,7 +23,7 @@ def wrap_parens(s):
ty = ty[len('Optional['):-1]
return wrap_parens(f'Some {strify(v, ty=ty, description=description, parens_if_space=True)}')
if isinstance(v, bool): return 'true' if v else 'false'
if isinstance(v, str) and ty == 'NormalizationType': return v
if isinstance(v, str) and ty in ('NormalizationType', 'ActivationKind'): return v
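    # Values typed as NormalizationType / ActivationKind name Coq variant
    # constructors, so the branch above emits them bare (e.g. relu) rather
    # than as the quoted string literals produced below.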
if isinstance(v, str): return '"' + repr(v)[1:-1] + '"'
if isinstance(v, torch.Tensor): return strify(v.detach().cpu().numpy(), ty=ty, description=description, parens_if_space=parens_if_space)
if isinstance(v, np.ndarray):
@@ -55,8 +55,13 @@ def coq_export_params_lines(model: HookedTransformer) -> Iterable[str]:
for f in dataclasses.fields(model.cfg):
val = dataclasses.asdict(model.cfg)[f.name]
ty = f.type
if f.name == 'attn_types' and ty == 'Optional[List]': ty = 'Optional[List[str]]'
if f.name == 'normalization_type' and ty == 'Optional[str]': ty = 'Optional[NormalizationType]'
for (name, expty, newty) in [('attn_types', 'Optional[List]', 'Optional[List[str]]'),
('normalization_type', 'Optional[str]', 'Optional[NormalizationType]'),
('act_fn', 'Optional[str]', 'Optional[ActivationKind]')]:
if f.name == name:
assert ty == expty, f'{f.name}.ty == {ty} != {expty}'
ty = newty
break
yield f' Definition {f.name} := {strify(val, ty=ty, description=f.name)}.'
yield 'End cfg.'

@@ -83,7 +88,7 @@ def coq_export_params_lines(model: HookedTransformer) -> Iterable[str]:
yield strify(getattr(model, name))
yield '.'


for layer, block in enumerate(model.blocks):
for module, names in (('ln1', ('b', 'w')), ('attn', ('W_Q', 'W_K', 'W_O', 'W_V', 'b_Q', 'b_K', 'b_O', 'b_V')), ):
if hasattr(block, module):
@@ -101,4 +106,4 @@
yield '.'
# %%
def coq_export_params(model: HookedTransformer):
return '\n'.join(coq_export_params_lines(model))
return '\n'.join(coq_export_params_lines(model))
