Add convert fp16 io to fp32 pass to webGPU example
devang-ml authored and xiaoyu-work committed May 13, 2024
1 parent f6fe26b commit 601f83b
Showing 3 changed files with 16 additions and 10 deletions.
examples/phi2/phi2.py (2 additions, 0 deletions)

@@ -154,6 +154,8 @@ def main(raw_args=None):
         template_json["systems"]["local_system"]["config"]["accelerators"] = [
             {"device": "GPU", "execution_providers": ["JsExecutionProvider"]}
         ]
+        fl_type = {"type": "OnnxIOFloat16ToFloat32"}
+        template_json["passes"]["fp32_logits"] = fl_type
         new_json_file = f"phi2_web.json"
         with open(new_json_file, "w") as f:
             json.dump(template_json, f, indent=4)
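Note: the two added lines register the new pass in the generated workflow config, so the resulting phi2_web.json gains a "passes" entry along these lines (a sketch; surrounding template keys are omitted and depend on the phi2 template):

"passes": {
    "fp32_logits": {
        "type": "OnnxIOFloat16ToFloat32"
    }
}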
olive/olive_config.json (1 addition, 1 deletion)

@@ -49,7 +49,7 @@
"module_path": "olive.passes.onnx.float16_conversion.OnnxFloatToFloat16"
},
"OnnxIOFloat16ToFloat32": {
"module_path": "olive.passes.onnx.float16_conversion.OnnxIOFloat16ToFloat32"
"module_path": "olive.passes.onnx.float32_conversion.OnnxIOFloat16ToFloat32"
},
"OnnxMatMul4Quantizer": {
"module_path": "olive.passes.onnx.quantization.OnnxMatMul4Quantizer"
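Note: this one-line fix matters because Olive instantiates each pass from its "module_path" string, and the old path pointed at the fp16 conversion module. A minimal sketch of how such a dotted path can be resolved in Python (illustrative only, not Olive's actual loader):

import importlib

def resolve_pass(module_path: str):
    # Split "olive.passes.onnx.float32_conversion.OnnxIOFloat16ToFloat32"
    # into a module name and a class name, then import and look up.
    module_name, class_name = module_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)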
olive/passes/onnx/float32_conversion.py (13 additions, 9 deletions)

@@ -24,13 +24,17 @@ class OnnxIOFloat16ToFloat32(Pass):
     def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         config = {
             "name_pattern": PassConfigParam(
-                type_=List[str], default_value="logits", description="Only convert inputs/outputs whose name matches this pattern"
+                type_=str, default_value="logits",
+                description=(
+                    "Only convert inputs/outputs whose name matches this pattern. "
+                    "By default, looks for logits names."
+                ),
             )
         }
         config.update(get_external_data_config())
         return config

-    def create_io_mapping(graph, i_map, o_map):
+    def create_io_mapping(self, graph, i_map, o_map):
         for n in graph.node:
             for i in n.input:
                 i_map[i].append(n)
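Note: with name_pattern now a plain string that _run_for_config compiles as a regex (see below), a workflow config can override the default. A hypothetical entry, assuming pass parameters nest under a "config" key as in other Olive pass configs:

"fp32_io": {
    "type": "OnnxIOFloat16ToFloat32",
    "config": {
        "name_pattern": "logits|present"
    }
}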
@@ -39,7 +43,7 @@ def create_io_mapping(graph, i_map, o_map):
                 assert o not in o_map[o]
                 o_map[o] = [n]

-    def wrap_inputs(graph, i_map, names):
+    def wrap_inputs(self, graph, i_map, names):
         # 1. find fp16 inputs
         # 2. rewrite all consumers
         # 3. insert cast
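Note: create_io_mapping builds two lookups over the ONNX graph: i_map maps each tensor name to the nodes consuming it, and o_map maps each tensor name to its single producer. A self-contained toy run of the same logic (the Add node and tensor names are invented for illustration):

from collections import defaultdict

from onnx import TensorProto, helper

# Toy graph with a single Add node: sum = x + y.
add_node = helper.make_node("Add", inputs=["x", "y"], outputs=["sum"])
graph = helper.make_graph(
    [add_node],
    "toy",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT16, [1]),
        helper.make_tensor_value_info("y", TensorProto.FLOAT16, [1]),
    ],
    outputs=[helper.make_tensor_value_info("sum", TensorProto.FLOAT16, [1])],
)

i_map, o_map = defaultdict(list), defaultdict(list)
for n in graph.node:
    for i in n.input:
        i_map[i].append(n)  # consumers of each tensor name
    for o in n.output:
        o_map[o] = [n]      # single producer of each tensor name

# i_map == {"x": [add_node], "y": [add_node]}; o_map == {"sum": [add_node]}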
@@ -65,7 +69,7 @@ def wrap_inputs(graph, i_map, names):
             i.type.tensor_type.elem_type = onnx.TensorProto.FLOAT


-    def wrap_outputs(graph, i_map, o_map, names):
+    def wrap_outputs(self, graph, i_map, o_map, names):
         # 1. find fp16 outputs
         # 2. rewrite all providers
         # 3. append cast
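Note: both wrappers share one mechanic: rename the fp16 tensor at the graph boundary, rewire its consumers (or producers) to the new name, and bridge the two names with a Cast node. A sketch of the input-side cast; the tensor name "logits_in" is hypothetical, not from this commit:

from onnx import TensorProto, helper

# The graph input keeps its public name but becomes fp32; a Cast node
# feeds the fp16 value to the renamed internal consumers.
cast = helper.make_node(
    "Cast",
    inputs=["logits_in"],        # fp32 at the graph boundary
    outputs=["logits_in_fp16"],  # consumers are rewired to this name
    to=TensorProto.FLOAT16,
)
# The pass would then insert `cast` into graph.node and flip the graph
# input's elem_type to TensorProto.FLOAT, as seen in wrap_inputs above.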
@@ -107,14 +111,14 @@ def _run_for_config(
         i_map = defaultdict(list)
         o_map = defaultdict(list)

-        self.create_io_mapping(model.graph, i_map, o_map)
+        self.create_io_mapping(ort_onnx_model.model.graph, i_map, o_map)

         pat = None
-        if args.name:
-            pat = re.compile(args.name)
+        if config["name_pattern"]:
+            pat = re.compile(config["name_pattern"])

-        self.wrap_inputs(model.graph, i_map, pat)
-        self.wrap_outputs(model.graph, i_map, o_map, pat)
+        self.wrap_inputs(ort_onnx_model.model.graph, i_map, pat)
+        self.wrap_outputs(ort_onnx_model.model.graph, i_map, o_map, pat)

         # save the model to the output path and return the model
         return model_proto_to_olive_model(ort_onnx_model.model, output_model_path, config)

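Note: replacing args.name with config["name_pattern"] is the functional core of this hunk; args does not appear to be defined anywhere in the pass (likely a leftover from a standalone script), so the old code path could not have run. A small sketch of the resulting gating behavior, assuming a search-style regex match inside the wrappers (an assumption, not shown in this diff):

import re
from typing import Optional

def should_convert(tensor_name: str, name_pattern: Optional[str]) -> bool:
    # No pattern configured: convert every fp16 input/output.
    if not name_pattern:
        return True
    return re.search(name_pattern, tensor_name) is not None

assert should_convert("logits", "logits")
assert should_convert("anything", None)
assert not should_convert("past_key_values.0", "logits")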