diff --git a/model/README.md b/model/README.md index a758105aae..59d1cefb9a 100644 --- a/model/README.md +++ b/model/README.md @@ -16,15 +16,49 @@ export DATA_PATH=$PWD/.cache export MODEL_PATH=$PWD/.saved_models ``` -2. Then download the OA data. +2. Then download the OA message tree JSONL file or declare the HuggingFace + dataset to use. + +Create a new configuration section, or modify an existing one, in the +`config.yaml` (SFT), `config_rm.yaml` (RM) or `config_rl.yaml` (RL) YAML +configuration files located in the `model_training/configs/` directory and +specify the OA JSONL data file or HuggingFace dataset to use. + +- To use a local OASST JSONL file (either `.jsonl` or `.jsonl.gz`) specify the + file name with the `input_file_path` configuration option. Either place the + file in the `cache_dir` (`DATA_PATH`) or specify an absolute path. ```bash -cp /path/to/ $DATA_PATH +cp /path/to/oasst_export.trees.jsonl.gz $DATA_PATH +``` + +Example: + +```yaml +my_data_config: + datasets: + - oasst_export: + input_file_path: oasst_export.trees.jsonl.gz ``` -Change the `` file used in the `model_training/configs/config.yaml`, -`model_training/configs/config_rl.yaml` and `reward/instructor/rank_datasets.py` -files. +- To use a HuggingFace dataset specify the dataset name with the + `hf_dataset_name` configuration option. + +Example: + +```yaml +my_data_config: + datasets: + - oasst_export: + hf_dataset_name: OpenAssistant/oasst1 +``` + +_Note_: If both `hf_dataset_name` and `input_file_path` are specified, +`input_file_path` will take precedence. + +See the +[OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1) +dataset card on the HuggingFace hub for more information. - (TODO) add better parsing of the config files that is consistent for sft, rm and rl training. 
diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index 52fd065896..a8fc2ea92a 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -186,7 +186,9 @@ oasst_only: datasets: - oasst_export: lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk" - input_file_path: 2023-04-04_oasst_ready.jsonl.gz + hf_dataset_name: OpenAssistant/oasst1 + #input_file_path: 2023-04-12_oasst_ready.trees.jsonl.gz + #top_k: 1 val_split: 0.05 sort_by_length: false use_custom_sampler: false @@ -206,14 +208,28 @@ oasst_export_eu: datasets: - oasst_export: lang: "en,es,de,fr" - input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz + hf_dataset_name: OpenAssistant/oasst1 + - gpt4all + - alpaca + - code_alpaca + - oig_file: + source_url: https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl + max_count: 10000 + min_length: 100 + val_split: 0.1 + - oig_file: + source_url: https://huggingface.co/datasets/laion/OIG/raw/main/unified_grade_school_math_instructions.jsonl + val_split: 0.1 + min_length: 100 + sort_by_length: false + use_custom_sampler: false oasst_export_latin_cyrillic: save_strategy: epoch datasets: - oasst_export: lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk" - input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz + hf_dataset_name: OpenAssistant/oasst1 - alpaca - oig_file: source_url: https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl @@ -364,7 +380,7 @@ llama-30b-sft-6: datasets: - oasst_export: lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk" - input_file_path: 2023-04-12_oasst_release_ready_synth.jsonl.gz + hf_dataset_name: OpenAssistant/oasst1 val_split: 0.05 - vicuna: val_split: 0.05 @@ -712,6 +728,7 @@ galactica-125m: gradient_accumulation_steps: 2 per_device_train_batch_size: 4 per_device_eval_batch_size: 4 + dtype: fp32 gpt-jt: learning_rate: 8e-6 @@ -761,3 +778,4 
@@ debug: log_wandb: false verbose: true num_train_epochs: 0.2 + dtype: fp32 diff --git a/model/model_training/configs/config_rm.yaml b/model/model_training/configs/config_rm.yaml index da7a2130b4..05741dfda3 100644 --- a/model/model_training/configs/config_rm.yaml +++ b/model/model_training/configs/config_rm.yaml @@ -49,7 +49,6 @@ oasst-rm-1-pythia-6.9b: pooling: last datasets: - augment_oasst: - #input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl input_file_path: augmented_latin_cyrillic_oasst_2023-03-27_v2.jsonl - anthropic_rlhf: fraction: 0.1 @@ -98,10 +97,9 @@ oasst-rm-1-pythia-2.8b: datasets: - oasst_export: lang: "en,es,de,fr" - input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz + hf_dataset_name: OpenAssistant/oasst1 val_split: 0.1 - augment_oasst: - #input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl input_file_path: augmented_latin_cyrillic_oasst_2023-03-27_v2.jsonl - anthropic_rlhf: fraction: 0.1 @@ -142,7 +140,7 @@ oasst-rm-1-pythia-1.4b: datasets: - oasst_export: lang: "en,es,de,fr" - input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz + hf_dataset_name: OpenAssistant/oasst1 val_split: 0.1 - augment_oasst: input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl diff --git a/model/model_training/custom_datasets/oasst_dataset.py b/model/model_training/custom_datasets/oasst_dataset.py index 35efd8f4f9..37bf2a71ca 100644 --- a/model/model_training/custom_datasets/oasst_dataset.py +++ b/model/model_training/custom_datasets/oasst_dataset.py @@ -1,8 +1,9 @@ from pathlib import Path -from typing import Literal, Optional +from typing import Iterable, Literal, Optional from model_training.custom_datasets.formatting import DatasetEntrySft, Role, Utterance -from oasst_data import ExportMessageNode, read_message_trees, visit_threads_depth_first +from oasst_data import ExportMessageNode, read_dataset_message_trees, read_message_trees, visit_threads_depth_first +from oasst_data.schemas import 
ExportMessageTree from torch import Generator from torch.utils.data import Dataset, random_split @@ -20,7 +21,8 @@ def __getitem__(self, index): def load_oasst_export( - input_file_path: str | Path, + input_file_path: Optional[str | Path] = None, + hf_dataset_name: Optional[str] = "OpenAssistant/oasst1", val_split: float = 0.2, lang: str = "en", top_k: Optional[int] = None, @@ -31,20 +33,27 @@ def load_oasst_export( if mode not in ("sft", "rm", "rl"): raise ValueError(f"Unknown dataset mode: {mode}") - lang_codes = lang.split(",") + lang_codes: list[str] = lang.split(",") generator = Generator() generator.manual_seed(manual_seed) - if not isinstance(input_file_path, Path): - input_file_path = Path(input_file_path) - if not input_file_path.is_absolute() and data_path: - if not isinstance(data_path, Path): - data_path = Path(data_path) - input_file_path = data_path / input_file_path + tree_iter: Iterable[ExportMessageTree] = None + if input_file_path: + if not isinstance(input_file_path, Path): + input_file_path = Path(input_file_path) + if not input_file_path.is_absolute() and data_path: + if not isinstance(data_path, Path): + data_path = Path(data_path) + input_file_path = data_path / input_file_path + tree_iter = read_message_trees(input_file_path) + elif hf_dataset_name: + tree_iter = read_dataset_message_trees(hf_dataset_name, split="train+validation") + else: + raise RuntimeError("Either `input_file_path` or `hf_dataset_name` must be specified.") threads_per_tree = [] - for tree in read_message_trees(input_file_path): + for tree in tree_iter: if tree.tree_state != "ready_for_export" or not tree.prompt.review_result or tree.prompt.lang not in lang_codes: continue @@ -145,6 +154,9 @@ def flatten(ds: ListDataset) -> ListDataset: train = flatten(splits[0]) val = flatten(splits[1]) - print(f"OASST data {str(input_file_path)}: {len(train)=}, {len(val)=}") + if input_file_path: + print(f"OASST JSONL file {str(input_file_path)}: {len(train)=}, {len(val)=}") + else: + 
print(f"OASST HF dataset {hf_dataset_name}: {len(train)=}, {len(val)=}") return train, val diff --git a/oasst-data/oasst_data/__init__.py b/oasst-data/oasst_data/__init__.py index 7137c88bdd..a2b8bc05c3 100644 --- a/oasst-data/oasst_data/__init__.py +++ b/oasst-data/oasst_data/__init__.py @@ -1,4 +1,11 @@ -from oasst_data.reader import read_message_list, read_message_tree_list, read_message_trees, read_messages +from oasst_data.reader import ( + read_dataset_message_trees, + read_dataset_messages, + read_message_list, + read_message_tree_list, + read_message_trees, + read_messages, +) from oasst_data.schemas import ( ExportMessageEvent, ExportMessageEventEmoji, @@ -33,4 +40,6 @@ "visit_messages_depth_first", "write_message_trees", "write_messages", + "read_dataset_message_trees", + "read_dataset_messages", ] diff --git a/oasst-data/oasst_data/reader.py b/oasst-data/oasst_data/reader.py index 0ee129616a..4aaf4f4ba2 100644 --- a/oasst-data/oasst_data/reader.py +++ b/oasst-data/oasst_data/reader.py @@ -4,6 +4,7 @@ from typing import Callable, Iterable, Optional, TextIO import pydantic +from datasets import load_dataset from .schemas import ExportMessageNode, ExportMessageTree @@ -17,22 +18,24 @@ def open_jsonl_read(input_file_path: str | Path) -> TextIO: return input_file_path.open("r", encoding="UTF-8") -def read_oasst_obj(line: str) -> ExportMessageTree | ExportMessageNode: - dict_tree = json.loads(line) +def read_oasst_obj(obj_dict: dict) -> ExportMessageTree | ExportMessageNode: # validate data - if "message_id" in dict_tree: - return pydantic.parse_obj_as(ExportMessageNode, dict_tree) - elif "message_tree_id" in dict_tree: - return pydantic.parse_obj_as(ExportMessageTree, dict_tree) + if "message_id" in obj_dict: + return pydantic.parse_obj_as(ExportMessageNode, obj_dict) + elif "message_tree_id" in obj_dict: + return pydantic.parse_obj_as(ExportMessageTree, obj_dict) raise RuntimeError("Unknown object in jsonl file") -def read_oasst_jsonl(input_file_path: str | 
Path) -> Iterable[ExportMessageTree | ExportMessageNode]: +def read_oasst_jsonl( + input_file_path: str | Path, +) -> Iterable[ExportMessageTree | ExportMessageNode]: with open_jsonl_read(input_file_path) as file_in: # read one object per line for line in file_in: - yield read_oasst_obj(line) + dict_tree = json.loads(line) + yield read_oasst_obj(dict_tree) def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTree]: @@ -42,11 +45,24 @@ def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTre def read_message_tree_list( - input_file_path: str | Path, filter: Optional[Callable[[ExportMessageTree], bool]] = None + input_file_path: str | Path, + filter: Optional[Callable[[ExportMessageTree], bool]] = None, ) -> list[ExportMessageTree]: return [t for t in read_message_trees(input_file_path) if not filter or filter(t)] +def convert_hf_message(row: dict) -> None: + emojis = row.get("emojis") + if emojis: + row["emojis"] = dict(zip(emojis["name"], emojis["count"])) + labels = row.get("labels") + if labels: + row["labels"] = { + name: {"value": value, "count": count} + for name, value, count in zip(labels["name"], labels["value"], labels["count"]) + } + + def read_messages(input_file_path: str | Path) -> Iterable[ExportMessageNode]: for x in read_oasst_jsonl(input_file_path): assert isinstance(x, ExportMessageNode) @@ -54,6 +70,60 @@ def read_messages(input_file_path: str | Path) -> Iterable[ExportMessageNode]: def read_message_list( - input_file_path: str | Path, filter: Optional[Callable[[ExportMessageNode], bool]] = None + input_file_path: str | Path, + filter: Optional[Callable[[ExportMessageNode], bool]] = None, ) -> list[ExportMessageNode]: return [t for t in read_messages(input_file_path) if not filter or filter(t)] + + +def read_dataset_message_trees( + hf_dataset_name: str = "OpenAssistant/oasst1", + split: str = "train+validation", +) -> Iterable[ExportMessageTree]: + dataset = load_dataset(hf_dataset_name, 
split=split) + + tree_dict: dict = None + parents: list = None + for row in dataset: + convert_hf_message(row) + if row["parent_id"] is None: + if tree_dict: + tree = read_oasst_obj(tree_dict) + assert isinstance(tree, ExportMessageTree) + yield tree + + tree_dict = { + "message_tree_id": row["message_id"], + "tree_state": row["tree_state"], + "prompt": row, + } + parents = [] + else: + while parents[-1]["message_id"] != row["parent_id"]: + parents.pop() + parent = parents[-1] + if "replies" not in parent: + parent["replies"] = [] + parent["replies"].append(row) + + row.pop("message_tree_id", None) + row.pop("tree_state", None) + parents.append(row) + + if tree_dict: + tree = read_oasst_obj(tree_dict) + assert isinstance(tree, ExportMessageTree) + yield tree + + +def read_dataset_messages( + hf_dataset_name: str = "OpenAssistant/oasst1", + split: str = "train+validation", +) -> Iterable[ExportMessageNode]: + dataset = load_dataset(hf_dataset_name, split=split) + + for row in dataset: + convert_hf_message(row) + message = read_oasst_obj(row) + assert isinstance(message, ExportMessageNode) + yield message diff --git a/oasst-data/pyproject.toml b/oasst-data/pyproject.toml index de81e69847..418517648e 100644 --- a/oasst-data/pyproject.toml +++ b/oasst-data/pyproject.toml @@ -7,7 +7,8 @@ authors = [ ] dependencies = [ "pydantic>=1.10.4", - "loguru==0.6.0" + "loguru==0.6.0", + "datasets>=2.12.0" ] [project.optional-dependencies]