diff --git a/llm/config/qwen/emb_argument.json b/llm/config/qwen/emb_argument.json index 7bd4430fb1f5..55e60b137c1f 100644 --- a/llm/config/qwen/emb_argument.json +++ b/llm/config/qwen/emb_argument.json @@ -3,7 +3,7 @@ "dataset_name_or_path": "./data", "output_dir": "./checkpoints/sft_ckpts", "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 128, + "gradient_accumulation_steps": 4, "per_device_eval_batch_size": 1, "eval_accumulation_steps": 1, "max_steps": 2000, @@ -15,7 +15,7 @@ "max_query_len": 1024, "max_passage_len": 2048, "group_size": 4, - "bp16": true, + "bf16": true, "fp16_opt_level": "O2", "do_train": true, "do_eval": false, @@ -30,5 +30,6 @@ "sharding": "stage2", "zero_padding": false, "unified_checkpoint": false, - "use_flash_attention": false + "use_flash_attention": true, + "amp_custom_black_list": "elementwise_div" } diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 63a1267a5569..99df142e826e 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass, field +from typing import List, Optional @dataclass @@ -83,3 +84,7 @@ class EmbeddingArgument: default=True, metadata={"help": "Whether to share the negatives across all GPUs."}, ) + embedding_matryoshka_dims: Optional[List[int]] = field( + default=None, + metadata={"help": "The dims for matryoshka training."}, + )