Skip to content

Commit

Permalink
Add torch.export test file (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 29, 2024
1 parent 38c8892 commit 7ba8fa3
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 28 deletions.
82 changes: 73 additions & 9 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4971,6 +4971,9 @@ python.Execution = class {
_insert(n) {
this._root.prepend(n);
}
output(result, type_expr) {
return this.create_node('output', 'output', new builtins.tuple(result), null, type_expr);
}
_target_to_str(target) {
if (typeof target === 'string') {
if (target.startsWith('__') && target.endswith('__')) {
Expand Down Expand Up @@ -6282,6 +6285,47 @@ python.Execution = class {
this.input_specs = input_specs;
this.output_specs = output_specs;
}
user_inputs() {
const user_inputs = [];
for (const s of this.input_specs) {
if (s.kind !== torch.export.graph_signature.InputKind.USER_INPUT) {
continue;
}
if (s.arg instanceof torch.export.graph_signature.TensorArgument ||
s.arg instanceof torch.export.graph_signature.SymIntArgument ||
s.arg instanceof torch.export.graph_signature.CustomObjArgument) {
user_inputs.push(s.arg.name);
} else if (s.arg instanceof torch.export.graph_signature.ConstantArgument) {
user_inputs.push(s.arg.value);
} else {
throw new python.Error(`Unsupported user input '${s.arg}'.`);
}
}
return user_inputs;
}
user_outputs() {
const user_outputs = [];
for (const s of this.output_specs) {
if (s.kind !== torch.export.graph_signature.OutputKind.USER_OUTPUT) {
continue;
}
if (s.arg instanceof torch.export.graph_signature.TensorArgument ||
s.arg instanceof torch.export.graph_signature.SymIntArgument ||
s.arg instanceof torch.export.graph_signature.CustomObjArgument) {
user_outputs.push(s.arg.name);
} else if (s.arg instanceof torch.export.graph_signature.ConstantArgument) {
user_outputs.push(s.arg.value);
} else {
throw new python.Error(`Unsupported user output '${s.arg}'.`);
}
}
return user_outputs;
}
inputs_to_parameters() {
return new Map(this.input_specs
.filter((s) => s.kind === torch.export.graph_signature.InputKind.PARAMETER && s.arg instanceof torch.export.graph_signature.TensorArgument && typeof s.target === 'string')
.map((s) => [s.arg.name, s.target]));
}
});
torch.export.graph_signature.InputKind = {
USER_INPUT: 0,
Expand Down Expand Up @@ -6606,6 +6650,11 @@ python.Execution = class {
this.input_specs.push(new torch._export.serde.schema.InputSpec({ user_input: { arg: { as_string: user_input } } }));
}
}
if (obj.inputs_to_parameters) {
for (const [input, parameter_name] of Object.entries(obj.inputs_to_parameters)) {
this.input_specs.push(new torch._export.serde.schema.InputSpec({ parameter: { arg: { name: input }, parameter_name } }));
}
}
this.output_specs = [];
if (Array.isArray(obj.output_specs)) {
this.output_specs = obj.output_specs.map((output_spec) => new torch._export.serde.schema.OutputSpec(output_spec));
Expand All @@ -6619,8 +6668,8 @@ python.Execution = class {
});
this.registerType('torch._export.serde.schema.InputToParameterSpec', class {
constructor(obj) {
Object.assign(this, { ...obj });
this.arg = new torch._export.serde.schema.TensorArgument(this.arg);
this.arg = new torch._export.serde.schema.TensorArgument(obj.arg);
this.parameter_name = obj.parameter_name;
}
});
this.registerType('torch._export.serde.schema.InputToBufferSpec', class {
Expand Down Expand Up @@ -6780,15 +6829,15 @@ python.Execution = class {
}
deserialize_graph_output(output) {
if (output.type === 'as_tensor') {
return this.serialized_name_to_node[output.as_tensor.name];
return this.serialized_name_to_node.get(output.as_tensor.name);
} else if (output.type === 'as_sym_int') {
return this.serialized_name_to_node[output.as_sym_int.as_name];
return this.serialized_name_to_node.get(output.as_sym_int.as_name);
} else if (output.type === 'as_sym_bool') {
return this.serialized_name_to_node[output.as_sym_bool.as_name];
return this.serialized_name_to_node.get(output.as_sym_bool.as_name);
} else if (output.type === 'as_int') {
return this.serialized_name_to_node[output.as_int.as_name];
return this.serialized_name_to_node.get(output.as_int.as_name);
} else if (output.type === 'as_none') {
return this.serialized_name_to_node[output.as_sym_bool.as_name];
return this.serialized_name_to_node.get(output.as_sym_bool.as_name);
}
throw new python.Error(`Unsupported graph node ${output.type}.`);
}
Expand Down Expand Up @@ -6825,10 +6874,25 @@ python.Execution = class {
const target = this.deserialize_operator(serialized_node.target);
this.deserialize_node(serialized_node, target);
}
const outputs = [];
let outputs = [];
for (const output of serialized_graph.outputs) {
outputs.push(this.deserialize_graph_output(output));
}
if (serialized_graph.is_single_tensor_return) {
[outputs] = outputs;
} else {
outputs = new builtins.tuple(outputs);
}
const output_node = this.graph.output(outputs);
if (serialized_graph.is_single_tensor_return) {
output_node.meta.set("val", output_node.args[0].meta.get('val'));
} else {
/* output_node.meta["val"] = tuple(
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in output_node.args[0]
) */
}
return self.graph;
}
deserialize_operator(serialized_target) {
let module = null;
Expand Down Expand Up @@ -7077,7 +7141,7 @@ python.Execution = class {
} else if (typ_ === 'as_tensors') {
const result = [];
for (const arg of value) {
result.append(this.serialized_name_to_node.get(arg.name));
result.push(this.serialized_name_to_node.get(arg.name));
}
return result;
} else if (typ_ === 'as_ints' || typ_ === 'as_floats' || typ_ === 'as_bools' || typ_ === 'as_strings') {
Expand Down
12 changes: 6 additions & 6 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -986,9 +986,9 @@
{ "name": "eps", "type": "float32" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" },
{ "type": "Tensor" }
{ "name": "output", "type": "Tensor" },
{ "name": "save_mean", "type": "Tensor" },
{ "name": "save_rstd", "type": "Tensor" }
]
},
{
Expand All @@ -1004,9 +1004,9 @@
{ "name": "eps", "type": "float32" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" },
{ "type": "Tensor" }
{ "name": "output", "type": "Tensor" },
{ "name": "save_mean", "type": "Tensor" },
{ "name": "save_rstd", "type": "Tensor" }
]
},
{
Expand Down
114 changes: 101 additions & 13 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,75 @@ pytorch.Graph = class {
} else if (type === 'torch.export.exported_program.ExportedProgram' && module.graph) {
const exported_program = module;
const graph = exported_program.graph;
const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters();
const values = new Map();
values.map = (obj) => {
if (!values.has(obj)) {
let type = null;
const val = obj.meta.get('val');
if (val) {
const dataType = val.dtype.__reduce__();
const shape = new pytorch.TensorShape(val.shape);
type = new pytorch.TensorType(dataType, shape);
}
const value = new pytorch.Value(obj.name, type);
values.set(obj, value);
}
return values.get(obj);
};
const nodes = new Map(graph.nodes.map((node) => [node.name, node]));
for (const node of graph.nodes) {
this.nodes.push(new pytorch.Node(metadata, node.name, null, node, null, values));
if (node.op === 'placeholder') {
if (inputs_to_parameters.has(node.name)) {
const key = inputs_to_parameters.get(node.name);
const parameter = exported_program.state_dict.get(key);
if (parameter) {
const tensor = new pytorch.Tensor(key, parameter.data);
const value = new pytorch.Value(key, null, null, tensor);
values.set(node, value);
}
}
}
}
for (const obj of graph.nodes) {
if (obj.op === 'placeholder' && obj.users.size <= 1) {
continue;
}
if (obj.op === 'call_function') {
if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
continue;
}
if (obj.users.size === 0) {
continue;
}
}
if (obj.op === 'output') {
for (const output of obj.args) {
if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
continue;
}
const value = values.map(output);
const argument = new pytorch.Argument(output.name, [value]);
this.outputs.push(argument);
}
continue;
}
const node = new pytorch.Node(metadata, obj.name, null, obj, null, values);
this.nodes.push(node);
}
for (const input_spec of exported_program.graph_signature.user_inputs()) {
const node = nodes.get(input_spec);
const value = values.map(node);
const argument = new pytorch.Argument(input_spec, [value]);
this.inputs.push(argument);
}
/*
for (const output_spec of exported_program.graph_signature.user_outputs()) {
const value = values.map(output_spec);
const argument = new pytorch.Argument(output_spec, [value]);
this.outputs.push(argument);
}
*/
} else if (pytorch.Utility.isTensor(module)) {
const node = new pytorch.Node(metadata, null, type, { value: module });
this.nodes.push(node);
Expand Down Expand Up @@ -420,26 +486,48 @@ pytorch.Node = class {
if (obj.op === 'call_function') {
this.type = createType(metadata, obj.target.name);
const schema = obj.target._schema;
const args = obj.args;
for (let i = 0; i < args.length; i++) {
const arg = args[i];
const name = schema && Array.isArray(schema.arguments) ? schema.arguments[i].name : '';
let args = obj.args.map((arg, index) => {
const name = schema && Array.isArray(schema.arguments) ? schema.arguments[index].name : '';
return [name, arg];
});
args = args.concat(Array.from(obj.kwargs));
for (const [name, arg] of args) {
if (pytorch.Utility.isInstance(arg, 'torch.fx.node.Node')) {
this.inputs.push(new pytorch.Argument(name, [values.map(arg.name)]));
const value = values.map(arg);
const argument = new pytorch.Argument(name, [value]);
this.inputs.push(argument);
} else if (Array.isArray(arg) && arg.every((arg) => pytorch.Utility.isInstance(arg, 'torch.fx.node.Node'))) {
const list = arg.map((arg) => values.map(arg));
const argument = new pytorch.Argument(name, list);
this.inputs.push(argument);
} else {
this.inputs.push(new pytorch.Argument(name, arg, 'attribute'));
const argument = new pytorch.Argument(name, arg, 'attribute');
this.inputs.push(argument);
}
}
for (const [name, arg] of obj.kwargs) {
if (pytorch.Utility.isInstance(arg, 'torch.fx.node.Node')) {
this.inputs.push(new pytorch.Argument(name, [values.map(arg.name)]));
} else {
this.inputs.push(new pytorch.Argument(name, arg, 'attribute'));
let outputs = [obj];
if (obj.users.size > 1) {
const users = Array.from(obj.users.keys());
if (users.every((user) => user.op === 'call_function' && user.target.__module__ === 'operator' && user.target.__name__ === 'getitem')) {
outputs = new Array(obj.users.size);
for (const user of users) {
const [, index] = user.args;
outputs[index] = user;
}
}
}
this.outputs.push(new pytorch.Argument('output', [values.map(obj.name)]));
for (let i = 0; i < outputs.length; i++) {
const node = outputs[i];
const value = values.map(node);
const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output';
const argument = new pytorch.Argument(name, [value]);
this.outputs.push(argument);
}
} else if (obj.op === 'placeholder') {
this.type = createType(metadata, 'placeholder');
const value = values.map(obj);
const argument = new pytorch.Argument('value', [value]);
this.inputs.push(argument);
} else {
throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`);
}
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5327,6 +5327,13 @@
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "pytorch",
"target": "inception_v3.pt2",
"source": "https://github.com/user-attachments/files/17179803/inception_v3.pt2.zip[inception_v3.pt2]",
"format": "PyTorch Export v5.3",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "pytorch",
"target": "inception_v3.pkl.pth.zip",
Expand Down

0 comments on commit 7ba8fa3

Please sign in to comment.