
Question) Fine tuning a transfer-learned model #1588

Closed
SyedUmairHassanKazmi opened this issue Feb 22, 2023 · 1 comment
Labels
question Further information is requested

Comments

@SyedUmairHassanKazmi

I know this is not a forum for help, but I couldn't find any forum where I could ask.
I have two questions.
Q1) How can I fine-tune a transfer-learned model in Darts by unfreezing some layers, or by adding new layers and retraining the model?
Q2) Is there any functionality for plotting training and validation accuracy against the number of epochs? Is there also something like TensorBoard?
Thanks

@SyedUmairHassanKazmi SyedUmairHassanKazmi added the triage Issue waiting for triaging label Feb 22, 2023
@madtoinou madtoinou added question Further information is requested and removed triage Issue waiting for triaging labels Feb 22, 2023
@madtoinou
Collaborator

Hi,

Q1)

  1. A PR covering the loading of weights from a checkpoint into a freshly instantiated model (with a new optimizer/learning-rate scheduler; the discussion can be found here) was merged yesterday and should cover the retraining aspect.
  2. Darts does not provide much flexibility in a model's architecture (new layers, custom activation functions); you would have to implement a new model if you want to be able to add layers (especially after pre-training the model).
  3. As far as I know, it should be possible to (un)freeze the layers of a Darts model as in any PyTorch model.
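To illustrate point 3, here is a minimal sketch of the standard PyTorch (un)freezing pattern. It assumes the underlying `torch.nn.Module` of a trained Darts model is accessible (e.g. via the model's `.model` attribute on a `TorchForecastingModel`, which is an assumption to verify for your Darts version); a plain `nn.Sequential` stands in for it here.

```python
import torch.nn as nn

# Stand-in for the PyTorch module inside a pre-trained Darts model
# (hypothetical structure, for illustration only).
net = nn.Sequential(
    nn.Linear(8, 16),   # pretrained "backbone" layer
    nn.ReLU(),
    nn.Linear(16, 1),   # "head" layer we want to keep trainable
)

# Freeze everything first...
for p in net.parameters():
    p.requires_grad = False

# ...then unfreeze only the last layer for fine-tuning.
for p in net[2].parameters():
    p.requires_grad = True

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)
```

Only the parameters with `requires_grad=True` will be updated by the optimizer during retraining.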

Q2)

  1. By default, Darts models already generate TensorBoard-compatible logs (via PyTorch Lightning) in the 'darts_logs/model_name/logs/' folder, which can be visualized with the TensorBoard app. You can tweak this using the logger keyword of the pl_trainer_kwargs argument in the models' constructors.
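As a hedged sketch of the above, `pl_trainer_kwargs` forwards keyword arguments to the PyTorch Lightning `Trainer`; the model class name and the exact keys below are assumptions to check against your Darts version.

```python
# Keyword arguments forwarded to pytorch_lightning.Trainer
# (hypothetical values, for illustration only).
pl_trainer_kwargs = {
    # True keeps Lightning's default TensorBoardLogger; a custom
    # pytorch_lightning.loggers.TensorBoardLogger("my_logs/") could
    # also be passed here instead.
    "logger": True,
    "log_every_n_steps": 1,  # log more often, useful for small datasets
}

# Hypothetical usage with a Darts model:
# model = RNNModel(input_chunk_length=24, pl_trainer_kwargs=pl_trainer_kwargs)
# model.fit(series)
# then inspect the run with: tensorboard --logdir darts_logs/
```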

Writing a tutorial notebook on "advanced" model fine-tuning is in the backlog; we are trying to focus on more urgent bugs/features for the moment. If you want to give it a try, contributions are welcome!

PS: You could try asking your questions on the project's Gitter, which can be considered a kind of forum.
