Commit

chore: gofmt, black, and prettier run across PBT changes
a9p committed Apr 10, 2022
1 parent dec0729 commit f6890a0
Showing 11 changed files with 408 additions and 186 deletions.
67 changes: 44 additions & 23 deletions examples/v1beta1/trial-images/simple-pbt/pbt_test.py

@@ -11,7 +11,8 @@
 import tensorflow as tf
 import time
 
-class PBTBenchmarkExample():
+
+class PBTBenchmarkExample:
     """Toy PBT problem for benchmarking adaptive learning rate.
     The goal is to optimize this trainable's accuracy. The accuracy increases
     fastest at the optimal lr, which is a function of the current accuracy.
@@ -36,24 +37,23 @@ def __init__(self, lr, log_dir: str, log_interval: int, checkpoint: str):
         self._log_interval = log_interval
         self._lr = lr
 
-        self._checkpoint_file = os.path.join(checkpoint, 'training.ckpt')
+        self._checkpoint_file = os.path.join(checkpoint, "training.ckpt")
         if os.path.exists(self._checkpoint_file):
-            with open(self._checkpoint_file, 'rb') as fin:
+            with open(self._checkpoint_file, "rb") as fin:
                 checkpoint_data = pickle.load(fin)
-                self._accuracy = checkpoint_data['accuracy']
-                self._step = checkpoint_data['step']
+                self._accuracy = checkpoint_data["accuracy"]
+                self._step = checkpoint_data["step"]
         else:
             os.makedirs(checkpoint, exist_ok=True)
             self._step = 1
             self._accuracy = 0.0
 
-
     def save_checkpoint(self):
-        with open(self._checkpoint_file, 'wb') as fout:
-            pickle.dump({'step': self._step, 'accuracy': self._accuracy}, fout)
+        with open(self._checkpoint_file, "wb") as fout:
+            pickle.dump({"step": self._step, "accuracy": self._accuracy}, fout)
 
     def step(self):
-        midpoint = 100 # lr starts decreasing after acc > midpoint
+        midpoint = 100  # lr starts decreasing after acc > midpoint
         q_tolerance = 3  # penalize exceeding lr by more than this multiple
         noise_level = 2  # add gaussian noise to the acc increase
         # triangle wave:
@@ -80,32 +80,53 @@ def step(self):
             if not self._writer:
                 self._writer = tf.summary.create_file_writer(self._log_dir)
             with self._writer.as_default():
-                tf.summary.scalar("Validation-accuracy", self._accuracy, step=self._step)
+                tf.summary.scalar(
+                    "Validation-accuracy", self._accuracy, step=self._step
+                )
                 tf.summary.scalar("lr", self._lr, step=self._step)
             self._writer.flush()
 
         self._step += 1
 
     def __repr__(self):
-        return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(self._step, self._lr, self._accuracy)
+        return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(
+            self._step, self._lr, self._accuracy
+        )
 
 
 if __name__ == "__main__":
     # Parse CLI arguments
-    parser = argparse.ArgumentParser(description='PBT Basic Test')
-    parser.add_argument('--lr', type=float, default=0.0001,
-                        help='learning rate (default: 0.0001)')
-    parser.add_argument('--epochs', type=int, default=20,
-                        help='number of epochs to train (default: 20)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status (default: 1)')
-    parser.add_argument('--log-path', type=str, default="/var/log/katib/tfevent/",
-                        help='tfevent output path (default: /var/log/katib/tfevent/)')
-    parser.add_argument('--checkpoint', type=str, default="/var/log/katib/checkpoints/",
-                        help='checkpoint directory (resume and save)')
+    parser = argparse.ArgumentParser(description="PBT Basic Test")
+    parser.add_argument(
+        "--lr", type=float, default=0.0001, help="learning rate (default: 0.0001)"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=20, help="number of epochs to train (default: 20)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status (default: 1)",
+    )
+    parser.add_argument(
+        "--log-path",
+        type=str,
+        default="/var/log/katib/tfevent/",
+        help="tfevent output path (default: /var/log/katib/tfevent/)",
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        default="/var/log/katib/checkpoints/",
+        help="checkpoint directory (resume and save)",
+    )
     opt = parser.parse_args()
 
-    benchmark = PBTBenchmarkExample(opt.lr, opt.log_path, opt.log_interval, opt.checkpoint)
+    benchmark = PBTBenchmarkExample(
+        opt.lr, opt.log_path, opt.log_interval, opt.checkpoint
+    )
     for i in range(opt.epochs):
         benchmark.step()
         time.sleep(0.2)

@@ -187,7 +187,7 @@ func (g *General) SyncAssignments(
 				assignment.Annotations[a.Name] = a.Value
 			}
 		}
-		trialAssignments = append(trialAssignments, assignment)
+		trialAssignments = append(trialAssignments, assignment)
 	}
 
 	for _, t := range trialAssignments {

@@ -10,7 +10,13 @@ import { PbtTabComponent } from './pbt-tab.component';
 
 @NgModule({
   declarations: [PbtTabComponent],
-  imports: [CommonModule, FormsModule, MatFormFieldModule, MatSelectModule, MatCheckboxModule],
+  imports: [
+    CommonModule,
+    FormsModule,
+    MatFormFieldModule,
+    MatSelectModule,
+    MatCheckboxModule,
+  ],
   exports: [PbtTabComponent],
 })
 export class PbtTabModule {}

@@ -1,18 +1,21 @@
 <div class="pbt-wrapper">
-
   <div class="pbt-options-wrapper">
     <mat-form-field appearance="fill" class="pbt-option">
       <mat-label>Y-Axis</mat-label>
-      <mat-select [(ngModel)]="selectedName" (ngModelChange)="onDropdownChange()">
+      <mat-select
+        [(ngModel)]="selectedName"
+        (ngModelChange)="onDropdownChange()"
+      >
         <mat-option *ngFor="let name of selectableNames" [value]="name">
-          {{name}}
+          {{ name }}
         </mat-option>
       </mat-select>
     </mat-form-field>
 
-    <mat-checkbox [(ngModel)]="displayTrace" (ngModelChange)="onTraceChange()">Display Seed Traces</mat-checkbox>
+    <mat-checkbox [(ngModel)]="displayTrace" (ngModelChange)="onTraceChange()"
+      >Display Seed Traces</mat-checkbox
+    >
   </div>
 
   <div #pbtGraph id="pbt-graph" class="d3-tab-graph"></div>
-
 </div>

@@ -17,7 +17,7 @@
 }
 
 .pbt-option {
-  margin: 10px
+  margin: 10px;
 }
 
 .d3-tab-graph {