Add tolerable_loss to TuningConfig (#1579)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Kaihui-intel authored Jan 30, 2024
1 parent 7bf89eb commit fb61428
Showing 5 changed files with 205 additions and 5 deletions.
neural_compressor/common/base_tuning.py (66 additions, 3 deletions)
@@ -213,14 +213,57 @@ class TuningConfig:
        config_set: quantization configs. Default value is empty.
        timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
        max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
+        tolerable_loss: This float indicates how much metric loss we can accept. \
+            The metric loss is relative; it can be either positive or negative. Default is 0.01.
+
+    Examples:
+        from neural_compressor import TuningConfig
+        tune_config = TuningConfig(
+            config_set=[config1, config2, ...],
+            max_trials=3,
+            tolerable_loss=0.01
+        )
+
+        # Case 1: Tolerable Loss
+        fp32_baseline = 100
+        config1_metric, config2_metric, ... = 98, 99, ...
+
+        # Tuning result of case 1:
+        # The best tuning config is config2, because config2_metric >= fp32_baseline * (1 - tolerable_loss).
+
+        # Case 2: Maximum Trials
+        fp32_baseline = 100
+        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
+
+        # Tuning result of case 2:
+        # The best tuning config is config1, because of the following:
+        # 1. No config achieved the set goal (config_metric < fp32_baseline * (1 - tolerable_loss)).
+        # 2. The maximum number of trials was reached, so the config with the highest metric wins.
+
+        # Case 3: Timeout
+        tune_config = TuningConfig(
+            config_set=[config1, config2, ...],
+            timeout=10,  # seconds
+            max_trials=3,
+            tolerable_loss=0.01
+        )
+        config1_tuning_time, config2_tuning_time, config3_tuning_time, ... = 4, 5, 6, ...  # seconds
+        fp32_baseline = 100
+        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
+
+        # Tuning result of case 3:
+        # The best tuning config is config1; once the timeout was hit, the third trial was forced to exit.
    """

-    def __init__(self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
+    def __init__(
+        self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None, tolerable_loss=0.01
+    ) -> None:
        """Init a TuneCriterion object."""
        self.config_set = config_set
        self.timeout = timeout
        self.max_trials = max_trials
        self.sampler = sampler
+        self.tolerable_loss = tolerable_loss


class _TrialRecord:
@@ -242,12 +285,17 @@ def __init__(self, tuning_config: TuningConfig) -> None:
        self.tuning_config = tuning_config
        self.trial_cnt = 0
        self.tuning_history: List[_TrialRecord] = []
+        self.baseline = None

    def add_trial_result(self, trial_index: int, trial_result: Union[int, float], quant_config: BaseConfig) -> None:
        self.trial_cnt += 1
        trial_record = _TrialRecord(trial_index, trial_result, quant_config)
        self.tuning_history.append(trial_record)

+    def set_baseline(self, baseline: float):
+        self.baseline = baseline
+        logger.info(f"Fp32 baseline is {self.baseline}")
+
    def get_number_of_trials(self):
        return len(self.tuning_history)

Expand All @@ -260,8 +308,23 @@ def get_best_quant_config(self) -> BaseConfig:
return sorted_trials_records[0].quant_config

def need_stop(self) -> bool:
# TODO Support more stop criteria in the next PR, such as `reach accuracy goal`, `timeout`, and so on.
return self.trial_cnt >= self.tuning_config.max_trials
"""Check if need to stop tuning. Either accuracy goal is met, max trials is reached or timeout is reached.
Returns:
bool: True if need to stop, otherwise False.
"""

# TODO: Support more stop criteria in the next PR, such as `timeout`, and so on.
# reach max trials
reach_max_trials = self.trial_cnt >= self.tuning_config.max_trials
# reach accuracy goal
meet_accuracy_goal = (
False
if self.baseline is None
else self.tuning_history[-1].trial_result >= (self.baseline * (1 - self.tuning_config.tolerable_loss))
)
# [-1] is the last element representing the latest trail record.
return reach_max_trials or meet_accuracy_goal


def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]:
Expand Down
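To see the new stop criterion in isolation, the following is a minimal, runnable sketch of the same logic; should_stop is a hypothetical stand-in for TuningMonitor.need_stop and is not part of this diff.

# Standalone sketch of the stop criterion added above; `should_stop` is hypothetical.
def should_stop(trial_results, baseline, tolerable_loss=0.01, max_trials=100):
    """Stop once the latest metric is within tolerable_loss of the baseline, or the trial budget is spent."""
    reach_max_trials = len(trial_results) >= max_trials
    meet_accuracy_goal = baseline is not None and trial_results[-1] >= baseline * (1 - tolerable_loss)
    return reach_max_trials or meet_accuracy_goal

# Case 1 from the docstring: baseline 100 and tolerable_loss 0.01 set the goal at a metric >= 99.
assert not should_stop([98], baseline=100)    # 98 < 99, keep tuning
assert should_stop([98, 99], baseline=100)    # 99 >= 99, stop; config2 wins
# Case 2: no config reaches 99, so tuning stops only when max_trials is exhausted.
assert should_stop([98, 98, 97], baseline=100, max_trials=3)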
neural_compressor/torch/quantization/autotune.py (4 additions, 1 deletion)
@@ -46,6 +46,8 @@ def autotune(
    evaluator.set_eval_fn_registry(eval_fns)
    evaluator.self_check()
    config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
+    baseline: float = evaluator.evaluate(model)
+    tuning_monitor.set_baseline(baseline)
    tuning_logger.tuning_start()
    for trial_index, quant_config in enumerate(config_loader):
        tuning_logger.trial_start(trial_index=trial_index)
@@ -58,11 +60,12 @@
        eval_result: float = evaluator.evaluate(q_model)
        tuning_logger.evaluation_end()
        tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
+        tuning_logger.trial_end(trial_index)
        if tuning_monitor.need_stop():
            best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
            # !!! Make sure to use deepcopy only when inplace is set to `True`.
            quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
            best_quant_model = model  # quantize model inplace
-            tuning_logger.trial_end(trial_index)
            break
    tuning_logger.tuning_end()
    return best_quant_model
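End to end, a call into the new API might look like the sketch below, adapted from the TuningConfig docstring example and the new test. The torch import path is an assumption based on this commit's file layout, and build_model and the body of eval_acc_fn are placeholders to supply yourself.

# Usage sketch under the assumptions stated above; not a verbatim part of this commit.
from neural_compressor import TuningConfig  # as in the docstring example
from neural_compressor.torch.quantization import RTNConfig, autotune  # assumed export path

def eval_acc_fn(model) -> float:
    # Placeholder: return the model's accuracy on your validation set.
    ...

tune_config = TuningConfig(
    config_set=[RTNConfig(bits=[4, 8])],
    max_trials=3,
    tolerable_loss=0.01,  # stop once accuracy >= fp32_baseline * 0.99
)
best_model = autotune(model=build_model(), tune_config=tune_config, eval_fns=eval_acc_fn)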
test/3x/common/test_common.py (38 additions, 1 deletion)
@@ -42,7 +42,13 @@

from typing import Any, Callable, List, Optional, Tuple, Union

-from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry, register_config
+from neural_compressor.common.base_config import (
+    BaseConfig,
+    ComposableConfig,
+    get_all_config_set_from_config_registry,
+    register_config,
+)
+from neural_compressor.common.base_tuning import ConfigLoader, Sampler
from neural_compressor.common.utils import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE

PRIORITY_FAKE_ALGO = 100
@@ -137,5 +143,36 @@ def test_api(self):
        self.assertEqual(config_set[0].weight_bits, DEFAULT_WEIGHT_BITS)


+class TestConfigLoader(unittest.TestCase):
+    def setUp(self):
+        self.config_set = [get_default_fake_config(), get_default_fake_config()]
+        self.loader = ConfigLoader(self.config_set, Sampler())
+
+    def test_parse_quant_config_single(self):
+        quant_config = get_default_fake_config()
+        result = ConfigLoader.parse_quant_config(quant_config)
+        self.assertEqual(str(result), str(quant_config.expand()))
+
+    def test_parse_quant_config_composable(self):
+        quant_config = get_default_fake_config()
+        composable_config = ComposableConfig(get_default_fake_config())
+        composable_config.config_list = [quant_config]
+        result = ConfigLoader.parse_quant_config(composable_config)
+        self.assertEqual(str(result), str(quant_config.expand()))
+
+    def test_parse_quant_configs(self):
+        quant_configs = [get_default_fake_config(), get_default_fake_config()]
+        self.config_set[0].expand = lambda: quant_configs
+        self.config_set[1].expand = lambda: []
+        result = self.loader.parse_quant_configs()
+        self.assertEqual(result, quant_configs)
+
+    def test_iteration(self):
+        quant_configs = [get_default_fake_config(), get_default_fake_config()]
+        self.loader.parse_quant_configs = lambda: quant_configs
+        result = list(self.loader)
+        self.assertEqual(result, quant_configs)
+
+
if __name__ == "__main__":
    unittest.main()
test/3x/common/test_utility.py (58 additions, 0 deletions)
@@ -0,0 +1,58 @@
+"""Tests for common components.
+
+!!! Please do not import any framework-specific modules in this file. !!!
+* Note, we may need to add some auto check mechanisms to ensure this.
+
+These tests aim to assess the fundamental functionalities of common utils and enhance code coverage.
+All tests will be included in each framework's CI.
+"""
+import unittest
+
+from neural_compressor.common import options
+from neural_compressor.common.utils import set_random_seed, set_resume_from, set_tensorboard, set_workspace
+
+
+class TestOptions(unittest.TestCase):
+    def test_set_random_seed(self):
+        seed = 12345
+        set_random_seed(seed)
+        self.assertEqual(options.random_seed, seed)
+
+        # non-int type
+        seed = "12345"
+        with self.assertRaises(AssertionError):
+            set_random_seed(seed)
+
+    def test_set_workspace(self):
+        workspace = "/path/to/workspace"
+        set_workspace(workspace)
+        self.assertEqual(options.workspace, workspace)
+
+        # non-string type
+        workspace = 12345
+        with self.assertRaises(AssertionError):
+            set_workspace(workspace)
+
+    def test_set_resume_from(self):
+        resume_from = "/path/to/resume"
+        set_resume_from(resume_from)
+        self.assertEqual(options.resume_from, resume_from)
+
+        # non-string type
+        resume_from = 12345
+        with self.assertRaises(AssertionError):
+            set_resume_from(resume_from)
+
+    def test_set_tensorboard(self):
+        tensorboard = True
+        set_tensorboard(tensorboard)
+        self.assertEqual(options.tensorboard, tensorboard)
+
+        # non-bool type
+        tensorboard = 123
+        with self.assertRaises(AssertionError):
+            set_tensorboard(tensorboard)
+
+
+if __name__ == "__main__":
+    unittest.main()
test/3x/torch/test_autotune.py (39 additions, 0 deletions)
@@ -200,6 +200,45 @@ def test_autotune_not_eval_func(self):
            str(context.exception), "Please ensure that you register at least one evaluation metric for auto-tune."
        )

+    def test_autotune_baseline(self):
+        logger.info("test_autotune_api")
+        from neural_compressor.common.base_tuning import evaluator
+
+        baseline = [1.0]
+
+        # case 1
+        # With the default tolerable_loss of 0.01, we expect tuning to end with a "2-trial end" output logged.
+        acc_res_lst = baseline + [0.9] * 2 + [0.99]
+
+        def eval_acc_fn(model):
+            res = acc_res_lst.pop(0)
+            return res
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6, 5, 8])], max_trials=6)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
+        self.assertIsNotNone(best_model)
+
+        # case 2
+        # With tolerable_loss=0.1, we expect tuning to end with a "0-trial end" output logged.
+        acc_res_lst = baseline + [0.9] * 2 + [0.99] + [1.01]
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6, 5, 8])], tolerable_loss=0.1)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
+        self.assertIsNotNone(best_model)
+
+        # case 3
+        # With tolerable_loss=-0.01, we expect tuning to end with a "3-trial end" output logged.
+        acc_res_lst = baseline + [0.9] * 2 + [0.99] + [1.01]
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6, 5, 8])], tolerable_loss=-0.01)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
+        self.assertIsNotNone(best_model)
+
+        # case 4
+        # With tolerable_loss=0.01 and accuracy that never meets the goal, we expect the best model to be None.
+        acc_res_lst = baseline + [0.9] * 2 + [0.9] + [0.9]
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6, 5, 8])], tolerable_loss=0.01)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
+        self.assertIsNone(best_model)
+

if __name__ == "__main__":
    unittest.main()
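The expected trial counts in the four cases above follow directly from the threshold fp32_baseline * (1 - tolerable_loss). The small, library-independent check below reproduces the arithmetic; first_stopping_trial is a hypothetical helper, not part of the test suite.

# Pure-Python check of the per-case thresholds used in test_autotune_baseline.
def first_stopping_trial(results, baseline=1.0, tolerable_loss=0.01):
    """Return the index of the first trial meeting the accuracy goal, or None."""
    threshold = baseline * (1 - tolerable_loss)
    return next((i for i, r in enumerate(results) if r >= threshold), None)

assert first_stopping_trial([0.9, 0.9, 0.99]) == 2                               # case 1: "2-trial end"
assert first_stopping_trial([0.9, 0.9, 0.99, 1.01], tolerable_loss=0.1) == 0     # case 2: threshold 0.9
assert first_stopping_trial([0.9, 0.9, 0.99, 1.01], tolerable_loss=-0.01) == 3   # case 3: threshold 1.01
assert first_stopping_trial([0.9, 0.9, 0.9, 0.9]) is None                        # case 4: goal never met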
