From 4fe4f0734c7f4d60bfac0a3979ab37fbcea7f0e5 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Mon, 23 Oct 2023 15:29:07 -0700 Subject: [PATCH] [CoqTorch] Add preliminary activation function support --- theories/MaxOfTwoNumbers/Parameters.v | 2 +- theories/MaxOfTwoNumbersSimpler/Parameters.v | 4 +- .../Theorems/Attempt_01b_logit_delta.v | 7 +- .../Parameters.v | 2 +- theories/Torch/Tensor.v | 5 +- theories/Torch/Tensor/Instances.v | 17 +++ theories/TransformerLens/HookedTransformer.v | 136 ++++++++++++++++++ .../HookedTransformer/Config/Common.v | 24 ++-- training/coq_export_utils.py | 17 ++- 9 files changed, 190 insertions(+), 24 deletions(-) diff --git a/theories/MaxOfTwoNumbers/Parameters.v b/theories/MaxOfTwoNumbers/Parameters.v index c0458bd..c69a7f2 100644 --- a/theories/MaxOfTwoNumbers/Parameters.v +++ b/theories/MaxOfTwoNumbers/Parameters.v @@ -16,7 +16,7 @@ Module cfg <: CommonConfig. Definition model_name := "custom". Definition n_heads := 1%N. Definition d_mlp := @None Z. - Definition act_fn := "relu". + Definition act_fn := @None ActivationKind. Definition d_vocab := 64%N. Definition eps := 1e-05. Definition use_attn_result := false. diff --git a/theories/MaxOfTwoNumbersSimpler/Parameters.v b/theories/MaxOfTwoNumbersSimpler/Parameters.v index b37deb4..0178425 100644 --- a/theories/MaxOfTwoNumbersSimpler/Parameters.v +++ b/theories/MaxOfTwoNumbersSimpler/Parameters.v @@ -14,7 +14,7 @@ Module cfg <: CommonConfig. Definition model_name := "custom". Definition n_heads := 1%N. Definition d_mlp := @None Z. - Definition act_fn := @None string. + Definition act_fn := @None ActivationKind. Definition d_vocab := 64%N. Definition eps := 0x1.4f8b588e368f1p-17. Definition use_attn_result := false. @@ -454,4 +454,4 @@ Definition L0_attn_b_O := . Definition L0_attn_b_V := [[0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0;0x0.0p+0]] -. \ No newline at end of file +. diff --git a/theories/MaxOfTwoNumbersSimpler/Theorems/Attempt_01b_logit_delta.v b/theories/MaxOfTwoNumbersSimpler/Theorems/Attempt_01b_logit_delta.v index 975170d..38430c6 100644 --- a/theories/MaxOfTwoNumbersSimpler/Theorems/Attempt_01b_logit_delta.v +++ b/theories/MaxOfTwoNumbersSimpler/Theorems/Attempt_01b_logit_delta.v @@ -373,7 +373,7 @@ Proof. specialize (H' i' (Z.to_nat (Uint63.to_Z i')) pf)); [ cbv [Reduction.in_bounds_alt_at]; clear; rewrite ?nat_N_Z, ?Z2Nat.id, ?of_to_Z by lia; - cbv [Classes.add Classes.mul Classes.zero int_has_add Classes.one int_has_one int_has_mul int_has_zero]; + cbv [Classes.add Classes.mul Classes.zero Classes.max int_has_add Classes.one Classes.eqb Uint63.max int_has_eqb int_has_one int_has_mul int_has_zero has_default_max_leb Classes.leb int_has_leb Uint63.leb] in *; match goal with | [ |- context[?v] ] => lazymatch v with context[i'] => fail | context[if _ then _ else _] => idtac end; @@ -382,6 +382,7 @@ Proof. | nat => idtac | int => idtac | N => idtac + | bool => idtac end; let v' := (eval vm_compute in v) in progress change v with v' @@ -403,9 +404,9 @@ Proof. cbv [inject_int] in *. specialize_step i'. specialize_step i'. - cbv [Classes.modulo int_has_modulo] in *. + cbv [Classes.modulo int_has_modulo Classes.max Uint63.max has_default_max_leb Classes.leb int_has_leb Uint63.leb] in *. set (i'' := (i' mod _)%uint63) in *. - assert (i' = i'') by (clear; subst i' i''; nia). + assert (i' = i'') by (clear; subst i' i''; try nia). clearbody i''; subst i''. move indices_of_max at bottom. subst min_incorrect_logit. diff --git a/theories/MaxOfTwoNumbersUndertrainedSimpler/Parameters.v b/theories/MaxOfTwoNumbersUndertrainedSimpler/Parameters.v index 235cc52..21170ba 100644 --- a/theories/MaxOfTwoNumbersUndertrainedSimpler/Parameters.v +++ b/theories/MaxOfTwoNumbersUndertrainedSimpler/Parameters.v @@ -14,7 +14,7 @@ Module cfg <: CommonConfig. Definition model_name := "custom". Definition n_heads := 1%N. Definition d_mlp := @None Z. - Definition act_fn := @None string. + Definition act_fn := @None ActivationKind. Definition d_vocab := 64%N. Definition eps := 0x1.4f8b588e368f1p-17. Definition use_attn_result := false. diff --git a/theories/Torch/Tensor.v b/theories/Torch/Tensor.v index 8ba99bc..b05edf0 100644 --- a/theories/Torch/Tensor.v +++ b/theories/Torch/Tensor.v @@ -6,8 +6,8 @@ From NeuralNetInterp.Util Require Import Wf_Uint63 PArray.Proofs List.Proofs Def Import Util.Nat.Notations. Import Util.Wf_Uint63.LoopNotation. Import Util.Wf_Uint63.Reduction. -Import Arith.Classes. Import Instances.Uint63. +Import Arith.Classes. Local Open Scope list_scope. Set Implicit Arguments. Import ListNotations. @@ -1073,6 +1073,9 @@ Definition reshape {A r1 r2} {s1 : Shape r1} (t : tensor s1 A) (s2 : Shape r2) : := unreshape_m1 (reshape_m1 t : tensor (Shape.reshape s2) A). *) +Definition relu {r} {s : Shape r} {A} {zeroA : has_zero A} {maxA : has_max A} (xs : tensor s A) : tensor s A + := map (max 0) xs. + Section reduce_axis_m1. Context {r} {s1 : Shape r} {s2 : ShapeType} {keepdim : with_default "keepdim" bool false} {A} diff --git a/theories/Torch/Tensor/Instances.v b/theories/Torch/Tensor/Instances.v index 71aaecf..c836665 100644 --- a/theories/Torch/Tensor/Instances.v +++ b/theories/Torch/Tensor/Instances.v @@ -427,6 +427,23 @@ Qed. #[export] Instance unreshape_all_Proper {r s A R} : Proper (eqfR R ==> eqfR R) (@unreshape_all r s A). Proof. apply unreshape_all_Proper_dep. Qed. + #[export] Instance relu_Proper_dep {r s} + : Dependent.Proper + (Dependent.idR + ==> (Dependent.idR ==> Dependent.idR ==> Dependent.idR) + ==> eqfR + ==> eqfR) + (@relu r s). + Proof. + repeat intro; cbv [relu]. + cbv -[RawIndex tensor Shape map] in *. + eapply map_Proper_dep; repeat intro; hnf in *; eauto. + Qed. + + #[export] Instance relu_Proper {r s A zeroA maxA} + : Proper (eqf ==> eqf) (@relu r s A zeroA maxA). + Proof. apply relu_Proper_dep; repeat intro; subst; reflexivity. Qed. + #[export] Instance sum_dim_m1_Proper_dep {r s1 s2 keepdim} : Dependent.Proper (Dependent.idR diff --git a/theories/TransformerLens/HookedTransformer.v b/theories/TransformerLens/HookedTransformer.v index bdf6139..08577d9 100644 --- a/theories/TransformerLens/HookedTransformer.v +++ b/theories/TransformerLens/HookedTransformer.v @@ -1,3 +1,90 @@ +(** This file corresponds roughly to components.py from + transformer_lens at + https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/components.py + + In this file, each class in components.py is turned into a Coq + Module purely for namespacing / code organizational purposes. + + - In this file, NN configuration parameters are taken as arguments + to each function. + + - In [HookedTransformer/Module.v], we organize the same building + blocks into module functors where the configuration parameters + are taken in as module functor arguments, using the code defined + in this file. + + - This file allows the potential for proving theorems about the + functions that are universally quantified over all parameters, + while the module functor organization allows the potential for + avoiding the overhead of passing around the parameters to every + function call (hopefully, I'm still figuring out ideal design + choices here). + + - It might be better to combine the files and stick with just the + module functor based organization, since that's the one we + ultimately use. + *) +(** Design principles: + + Two goals for code in this file: + + 1. Be able to run the code (efficiently, on primitive floats and + arrays) + + 2. Be able to reuse the same code to prove theorems about the + computation (using, e.g., reals or rationals instead of + primitive floats, so we get more nice mathematical properties). + + As a result, we want to: + + - parameterize each function over the type of data (and the + associated operations on that datatype that we need) + + - order the arguments so that we can define [Proper] instances + relating instantiations on different datatypes; arguments that + are the same across instantiations (such as vocab size) come + first, while arguments that vary (such as how to do addition on + the datatype) come later. + + Additionally, we parameterize functions over the batch size and + the shape of the tensor. + + Ordering of the particular datatype operations is currently a bit + ad-hoc and disorganized, though some effort is made to maintain + consistency across components. + + In each component, arguments are specified with a [Context] + directive in an unnamed [Section] to allow sharing of argument + declarations between different functions within that component. + Coq automatically determines the subset of arguments used for each + function. + + - Most arguments are implicit and maximally inserted (using curly + braces [{}]) so that they are picked up by type inference or + typeclass resolution. Exceptions are the weights and biases and + the tensors that are passed into the python code. + + + - We pass a [use_checkpoint : with_default "use_checkpoint" bool + true] argument which specifies whether or not we materialize + concrete arrays at various points in the code. + + - Because our tensors are represented as functions from indices + to data, by default computation is lazy and deferred until + concrete indices are given. + + - This is useful to avoid duplicating computation when + broadcasting scalars, but would involve massive duplication in + other cases such as recomputing the same mean repeatedly for + every index separately. + + - Hence we use [PArray.maybe_checkpoint] to materialize + computations before any step that might duplicate them. + + - Materializing arrays gets in the way of proofs, so we use this + parameter to ease infrastructure for removing all array + materialization simultaneously in a single proof step. *) + From Coq Require Import Floats Sint63 Uint63 QArith Lia List PArray Morphisms RelationClasses. From NeuralNetInterp.Util Require Import Default Pointed PArray List Notations Arith.Classes Arith.Instances Bool PrimitiveProd. From NeuralNetInterp.Util Require Nat Wf_Uint63. @@ -13,6 +100,7 @@ Local Open Scope list_scope. Set Implicit Arguments. Import ListNotations. Local Open Scope raw_tensor_scope. +Local Open Scope core_scope. Notation tensor_of_list ls := (Tensor.PArray.abstract (Tensor.PArray.concretize (Tensor.of_list ls))) (only parsing). @@ -219,6 +307,54 @@ Module Attention. End __. End Attention. +Module MLP. + Section __. + Context {r} {batch : Shape r} + {pos d_model d_mlp} + (act_fn_kind : ActivationKind) + {A} + {addA : has_add A} {mulA : has_mul A} + {maxA : has_max A} + {zeroA : has_zero A} + {use_checkpoint : with_default "use_checkpoint" bool true} + (W_in : tensor [d_model; d_mlp] A) (b_in : tensor [d_mlp] A) + (W_out : tensor [d_mlp; d_model] A) (b_out : tensor [d_model] A) + (x : tensor (batch ::' pos ::' d_model) A) + . + + Definition pre_act : tensor (batch ::' pos ::' d_mlp) A + := let x' : tensor (batch ::' pos ::' d_mlp) A + := Tensor.map' + (fun x : tensor [pos; d_mlp] A + => weaksauce_einsum {{{ {{ pos d_model , d_model d_mlp -> pos d_mlp }} + , x + , W_in }}} + : tensor [pos; d_mlp] A) + x in + x' + broadcast b_in. + + Definition act_fn : tensor (batch ::' pos ::' d_mlp) A -> tensor (batch ::' pos ::' d_mlp) A + := match act_fn_kind with + | relu => Tensor.relu + end. + + (* TODO: if act_fn is *_ln, then handle ln *) + Definition post_act : tensor (batch ::' pos ::' d_mlp) A + := act_fn pre_act. + + Definition forward : tensor (batch ::' pos ::' d_model) A + := let fx' : tensor (batch ::' pos ::' d_model) A + := Tensor.map' + (fun fx : tensor [pos; d_mlp] A + => weaksauce_einsum {{{ {{ pos d_mlp , d_mlp d_model -> pos d_model }} + , fx + , W_out }}} + : tensor [pos; d_mlp] A) + post_act in + fx' + broadcast b_out. + End __. +End MLP. + Module TransformerBlock. Section __. Context {r} {batch : Shape r} diff --git a/theories/TransformerLens/HookedTransformer/Config/Common.v b/theories/TransformerLens/HookedTransformer/Config/Common.v index e6fb8de..f86ed2f 100644 --- a/theories/TransformerLens/HookedTransformer/Config/Common.v +++ b/theories/TransformerLens/HookedTransformer/Config/Common.v @@ -1,6 +1,6 @@ (** Ported from https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/HookedTransformerConfig.py *) From Coq Require Import Floats Uint63 ZArith NArith. -From NeuralNetInterp.Util Require Import Default. +From NeuralNetInterp.Util Require Import Option Default. (** Copying the docstring from Python: <<< @@ -124,6 +124,14 @@ From NeuralNetInterp.Util Require Import Default. >>> *) Variant NormalizationType := LN (* | LNPre *) . +Variant ActivationKind := + | relu +(* | gelu *) +(* | silu *) +(* | gelu_new *) +(* | solu_ln *) +(* | gelu_fast *) +. Module Type CommonConfig. Parameter d_model : N. @@ -132,8 +140,12 @@ Module Type CommonConfig. Parameter d_vocab : N. Parameter d_vocab_out : N. Parameter n_heads : N. - Parameter eps : float. + #[local] Set Warnings "-inexact-float". + Parameter eps : with_default "eps" float (1e-5)%float. + #[local] Set Warnings "inexact-float". Parameter normalization_type : with_default "normalization_type" (option NormalizationType) (Some LN). + Parameter act_fn : with_default "act_fn" (option ActivationKind) None. + Definition attn_only : with_default "attn_only" bool false := Option.is_None act_fn. (*Parameter use_split_qkv_input : with_default "use_split_qkv_input" bool false.*) (*Notation maybe_n_heads := (if use_split_qkv_input as b return Shape (if b then _ else _) then [n_heads] else [])%shape (only parsing).*) End CommonConfig. @@ -149,14 +161,6 @@ Import RecordSetNotations. #[local] Set Decidable Equality Schemes. #[local] Set Boolean Equality Schemes. -Variant ActivationKind := - | relu -(* | gelu *) -(* | silu *) -(* | gelu_new *) -(* | solu_ln *) -(* | gelu_fast *) -. (** Copying the docstring from Python: diff --git a/training/coq_export_utils.py b/training/coq_export_utils.py index 21395a9..fb0fe9e 100644 --- a/training/coq_export_utils.py +++ b/training/coq_export_utils.py @@ -7,7 +7,7 @@ def strify(v, ty=None, description=None, parens_if_space=False): - tymap = {'int': 'Z', 'float':'Q', 'str':'string', 'bool':'bool', 'NormalizationType':'NormalizationType'} + tymap = {'int': 'Z', 'float':'Q', 'str':'string', 'bool':'bool', 'NormalizationType':'NormalizationType', 'ActivationKind':'ActivationKind'} def wrap_parens(s): return f'({s})' if parens_if_space else s if v is None: @@ -23,7 +23,7 @@ def wrap_parens(s): ty = ty[len('Optional['):-1] return wrap_parens(f'Some {strify(v, ty=ty, description=description, parens_if_space=True)}') if isinstance(v, bool): return 'true' if v else 'false' - if isinstance(v, str) and ty == 'NormalizationType': return v + if isinstance(v, str) and ty in ('NormalizationType', 'ActivationKind'): return v if isinstance(v, str): return '"' + repr(v)[1:-1] + '"' if isinstance(v, torch.Tensor): return strify(v.detach().cpu().numpy(), ty=ty, description=description, parens_if_space=parens_if_space) if isinstance(v, np.ndarray): @@ -55,8 +55,13 @@ def coq_export_params_lines(model: HookedTransformer) -> Iterable[str]: for f in dataclasses.fields(model.cfg): val = dataclasses.asdict(model.cfg)[f.name] ty = f.type - if f.name == 'attn_types' and ty == 'Optional[List]': ty = 'Optional[List[str]]' - if f.name == 'normalization_type' and ty == 'Optional[str]': ty = 'Optional[NormalizationType]' + for (name, expty, newty) in [('attn_types', 'Optional[List]', 'Optional[List[str]]'), + ('normalization_type', 'Optional[str]', 'Optional[NormalizationType]'), + ('act_fn', 'Optional[str]', 'Optional[ActivationKind]')]: + if f.name == name: + assert ty == expty, f'{f.name}.ty == {ty} != {expty}' + ty = newty + break yield f' Definition {f.name} := {strify(val, ty=ty, description=f.name)}.' yield 'End cfg.' @@ -83,7 +88,7 @@ def coq_export_params_lines(model: HookedTransformer) -> Iterable[str]: yield strify(getattr(model, name)) yield '.' - + for layer, block in enumerate(model.blocks): for module, names in (('ln1', ('b', 'w')), ('attn', ('W_Q', 'W_K', 'W_O', 'W_V', 'b_Q', 'b_K', 'b_O', 'b_V')), ): if hasattr(block, module): @@ -101,4 +106,4 @@ def coq_export_params_lines(model: HookedTransformer) -> Iterable[str]: yield '.' # %% def coq_export_params(model: HookedTransformer): - return '\n'.join(coq_export_params_lines(model)) \ No newline at end of file + return '\n'.join(coq_export_params_lines(model))