
Commit

Merge pull request #26 from shreydan/vit-transfer-learning-od
Transfer learning sub section in the fine tuning notebook
shreydan authored Feb 7, 2024
2 parents 33edf14 + db943e1 commit 1ebfc9d
Showing 1 changed file with 52 additions and 1 deletion.
@@ -17,7 +17,9 @@
"\twidth=\"850\"\n",
"\theight=\"450\">\n",
"</iframe>\n",
"```"
"```\n",
"\n",
"Also there is a small section if you are interested in Transfer learning instead of fine tuning only."
]
},
{
@@ -1526,6 +1528,55 @@
"source": [
"Well, that's not bad. We can improve the results if we fine-tune further. You can find this fine-tuned checkpoint [here](hf-vision/detr-resnet-50-dc5-harhat-finetuned). "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How about Transfer learning ?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook , we primarily discussed about the fine-tuning a certain model to our custom dataset. What if , we only want transfer learning? Actually that is easy peasy! In transfer learning , we have to keep the parameter values aka weights, of the pretrained model frozen. We just train the classifier layer (in some cases, one or two more layers). In this case before starting the training process, we can do the following, "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from transformers import AutoModelForObjectDetection\n",
"\n",
"id2label = {0:'head', 1:'helmet', 2:'person'}\n",
"label2id = {v: k for k, v in id2label.items()}\n",
"\n",
"\n",
"model = AutoModelForObjectDetection.from_pretrained(\n",
" checkpoint,\n",
" id2label=id2label,\n",
" label2id=label2id,\n",
" ignore_mismatched_sizes=True,\n",
")\n",
"\n",
"for name,p in model.named_parameters():\n",
" if not 'bbox_predictor' in name or not name.startswith('class_label'):\n",
" p.requires_grad = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That means, after loading the model , we freeze all of the layers except last 6 layers. Which are `bbox_predictor.layers` and `class_labels_classifier`."
]
}
],
"metadata": {
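As a quick sanity check on the freezing condition added in the diff, it can be exercised on a few parameter names in isolation. The sketch below is illustrative only: `keeps_grad` and the sample names are hypothetical stand-ins (in the notebook, the real names come from `model.named_parameters()` on the DETR checkpoint).

```python
# Hypothetical parameter names resembling those of a DETR model.
param_names = [
    "model.encoder.layers.0.self_attn.k_proj.weight",
    "model.decoder.layers.5.fc2.bias",
    "bbox_predictor.layers.0.weight",
    "bbox_predictor.layers.2.bias",
    "class_labels_classifier.weight",
    "class_labels_classifier.bias",
]

def keeps_grad(name: str) -> bool:
    # Same condition as the training loop: only the two heads stay trainable.
    return "bbox_predictor" in name or "class_labels_classifier" in name

trainable = [n for n in param_names if keeps_grad(n)]
print(trainable)  # only the bbox_predictor and class_labels_classifier names
```

Only the head parameters survive the filter; everything in the encoder and decoder would have `requires_grad` set to `False`.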
