Skip to content

Commit

Permalink
examples: Add DVCLive-HuggingFace notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Aug 4, 2023
1 parent 24fb584 commit 8cac907
Showing 1 changed file with 167 additions and 0 deletions.
167 changes: 167 additions & 0 deletions examples/DVCLive-HuggingFace.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install accelerate datasets dvclive evaluate 'transformers[torch]' --upgrade"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up a Git repository, initialise DVC inside it, and make the first commit.\n",
"!git init --quiet\n",
"!git config --local user.name \"Your Name\"\n",
"!git config --local user.email \"you@example.com\"\n",
"!dvc init --quiet\n",
"!git commit -m \"DVC init\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"\n",
"dataset = load_dataset(\"imdb\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n",
"\n",
"\n",
"def tokenize_function(examples):\n",
"    \"\"\"Tokenize the `text` field, padding to the maximum length and truncating.\"\"\"\n",
"    texts = examples[\"text\"]\n",
"    return tokenizer(texts, padding=\"max_length\", truncation=True)\n",
"\n",
"\n",
"def _shuffled_tokenized_subset(split, size):\n",
"    \"\"\"Shuffle `split` with seed 42, keep its first `size` rows and tokenize them.\"\"\"\n",
"    subset = dataset[split].shuffle(seed=42).select(range(size))\n",
"    return subset.map(tokenize_function, batched=True)\n",
"\n",
"\n",
"small_train_dataset = _shuffled_tokenized_subset(\"train\", 2000)\n",
"small_eval_dataset = _shuffled_tokenized_subset(\"test\", 200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import evaluate\n",
"\n",
"metric = evaluate.load(\"f1\")\n",
"\n",
"\n",
"def compute_metrics(eval_pred):\n",
"    \"\"\"Score a (logits, labels) pair with the loaded F1 metric.\n",
"\n",
"    The predicted class of each row is the index of its largest logit.\n",
"    \"\"\"\n",
"    logits, labels = eval_pred\n",
"    predicted_classes = np.argmax(logits, axis=-1)\n",
"    return metric.compute(references=labels, predictions=predicted_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking experiments with DVCLive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dvclive.huggingface import DVCLiveCallback\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"# Train one run per epoch budget, each starting from a freshly loaded pretrained model.\n",
"for epochs in (5, 10, 15):\n",
"    model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
"    # Freeze every parameter of the base model.\n",
"    for param in model.base_model.parameters():\n",
"        param.requires_grad = False\n",
"\n",
"    # NOTE(review): recent transformers releases appear to rename `evaluation_strategy`\n",
"    # to `eval_strategy`; confirm against the installed version.\n",
"    training_args = TrainingArguments(\n",
"        evaluation_strategy=\"epoch\",\n",
"        learning_rate=3e-4,\n",
"        logging_strategy=\"epoch\",\n",
"        num_train_epochs=epochs,\n",
"        output_dir=\"output\",\n",
"        overwrite_output_dir=True,\n",
"        load_best_model_at_end=True,\n",
"        report_to=\"none\",\n",
"        save_strategy=\"epoch\",\n",
"        weight_decay=0.01,\n",
"    )\n",
"\n",
"    # A new DVCLiveCallback is created for every run.\n",
"    trainer = Trainer(\n",
"        model=model,\n",
"        args=training_args,\n",
"        train_dataset=small_train_dataset,\n",
"        eval_dataset=small_eval_dataset,\n",
"        compute_metrics=compute_metrics,\n",
"        callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=\"last\")],\n",
"    )\n",
"    trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing experiments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import dvc.api\n",
"import pandas as pd\n",
"\n",
"columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n",
"\n",
"# Keep only the rows that have a value in every selected column, renumbered from 0.\n",
"df = (\n",
"    pd.DataFrame(dvc.api.exp_show(), columns=columns)\n",
"    .dropna()\n",
"    .reset_index(drop=True)\n",
")\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pass every experiment name printed by `dvc exp list` to `dvc plots diff`.\n",
"!dvc plots diff $(dvc exp list --names-only)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML\n",
"\n",
"# Show the HTML page stored under ./dvc_plots inline in the notebook.\n",
"plots_page = './dvc_plots/index.html'\n",
"HTML(filename=plots_page)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 8cac907

Please sign in to comment.