Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix tensorflow import #2481

Merged
merged 3 commits into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ With authors' permission, we listed a set of NNI usage examples and relevant art
Join IM discussion groups:
|Gitter||WeChat|
|----|----|----|
|![image](https://user-images.githubusercontent.com/39592018/80665738-e0574a80-8acc-11ea-91bc-0836dc4cbf89.png)| OR |![image](https://github.com/JSong-Jia/NNI-user-group/blob/master/user%20group%20code_0512.jpg)|
|![image](https://user-images.githubusercontent.com/39592018/80665738-e0574a80-8acc-11ea-91bc-0836dc4cbf89.png)| OR |![image](https://github.com/JSong-Jia/NNI-user-group/blob/master/user%20group%20code_0512.png)|


## Related Projects
Expand Down
1 change: 0 additions & 1 deletion examples/nas/enas-tf/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT license.

import tensorflow as tf
from tensorflow.data import Dataset

def get_dataset():
(x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.cifar10.load_data()
Expand Down
13 changes: 6 additions & 7 deletions src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging

import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.optimizers import Adam

from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
Expand Down Expand Up @@ -39,9 +38,9 @@ def __init__(self, model, loss, metrics, reward_function, optimizer, batch_size,

x, y = dataset_train
split = int(len(x) * 0.9)
self.train_set = Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = Dataset.from_tensor_slices(dataset_valid)
self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)

self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=mutator_lr)
Expand Down Expand Up @@ -151,9 +150,9 @@ def validate_one_epoch(self, epoch):


def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).batch(self.batch_size)
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set)

def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size))