Skip to content

Commit

Permalink
allow set workspace path for inc quantizer (#987)
Browse files Browse the repository at this point in the history
## Describe your changes

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
  • Loading branch information
guotuofeng authored Mar 6, 2024
1 parent 603e8a7 commit fed7c66
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
13 changes: 13 additions & 0 deletions olive/passes/onnx/inc_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
If users set domain as auto, automatic detection for domain will be executed.
""",
),
"workspace": PassConfigParam(
type_=str,
default_value=None,
description="""Workspace for Intel® Neural Compressor quantization where intermediate files and
tuning history file are stored. Default value is:
"./nc_workspace/{}/".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))""",
),
"recipes": PassConfigParam(
type_=dict,
default_value={},
Expand Down Expand Up @@ -505,6 +512,12 @@ def _run_for_config(
eval_func, accuracy_criterion, tuning_criterion = self._set_tuning_config(run_config, data_root)
weight_only_config = self._set_woq_config(run_config)

workspace = run_config.pop("workspace", None)
if workspace:
from neural_compressor import set_workspace

set_workspace(workspace)

# keys not needed for quantization
to_delete = [
"script_dir",
Expand Down
9 changes: 6 additions & 3 deletions test/unit_test/workflows/test_whisper_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def download_audio_test_data(tmp_path):


@pytest.fixture(name="whisper_config")
def prepare_whisper_config(audio_data):
def prepare_whisper_config(audio_data, tmp_path):
return {
"input_model": {
"type": "PyTorchModel",
Expand Down Expand Up @@ -79,6 +79,9 @@ def prepare_whisper_config(audio_data):
"inc_dynamic_quantization": {
"type": "IncDynamicQuantization",
"disable_search": True,
"config": {
"workspace": str(tmp_path / "workspace"),
},
},
"insert_beam_search": {
"type": "InsertBeamSearch",
Expand All @@ -102,8 +105,8 @@ def prepare_whisper_config(audio_data):
"evaluate_input_model": False,
"execution_providers": ["CPUExecutionProvider"],
"clean_cache": True,
"cache_dir": "cache",
"output_dir": "models",
"cache_dir": str(tmp_path / "cache"),
"output_dir": str(tmp_path / "models"),
"output_name": "whisper_cpu_fp32",
},
}
Expand Down

0 comments on commit fed7c66

Please sign in to comment.