From 89ad09b0134c24af647c522f2d62157b7eedd997 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 26 Sep 2024 06:20:29 -0700 Subject: [PATCH] Update python.js (#1211) --- source/python.js | 115 ++++++++++++++++++++++++++++++++++++---------- source/pytorch.js | 17 +++++-- 2 files changed, 105 insertions(+), 27 deletions(-) diff --git a/source/python.js b/source/python.js index 9b639ba232..e45b41c325 100644 --- a/source/python.js +++ b/source/python.js @@ -4865,13 +4865,15 @@ python.Execution = class { return chars.join(''); } }); - torch.fx.Graph = torch.fx.graph.Graph; this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module { - constructor(root, graph) { + constructor(root, graph, class_name) { super(); + this.__class__.__name__ = class_name || 'GraphModule'; this.graph = graph; } }); + torch.fx.Graph = torch.fx.graph.Graph; + torch.fx.GraphModule = torch.fx.graph_module.GraphModule; this.registerType('torch.fx.immutable_collections.immutable_dict', class extends builtins.dict {}); this.registerFunction('torch.fx._symbolic_trace.wrap', (fn_or_name) => { return fn_or_name; @@ -5988,7 +5990,45 @@ python.Execution = class { this.registerType('torch.export.graph_signature.OutputKind', class extends this._enum.Enum {}); this.registerType('torch.export.graph_signature.OutputSpec', class extends this._enum.Enum {}); this.registerType('torch.export.graph_signature.TensorArgument', class {}); - this.registerType('torch.export.exported_program.ExportedProgram', class {}); + this.registerType('torch.export.exported_program.ExportedProgram', class { + constructor(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, verifier, tensor_constants, constants) { + // graph._codegen = torch.fx.graph.CodeGen() + this._graph_module = this._create_graph_module_for_export(root, graph); + if (root instanceof torch.fx.GraphModule) { + // this._graph_module.meta.update(root.meta); + } + this._graph_signature = graph_signature; + this._state_dict = state_dict; + this._range_constraints = range_constraints; + this._module_call_graph = module_call_graph; + this._example_inputs = example_inputs; + this._constants = tensor_constants || constants || {}; + + /* + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, + state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]], + range_constraints: "Dict[sympy.Symbol, Any]", + module_call_graph: List[ModuleCallEntry], + example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None, + constants: Optional[ + Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]] + ] = None, + *, + verifiers: Optional[List[Type[Verifier]]] = None, + */ + } + _create_graph_module_for_export(root, graph) { + let gm = null; + try { + gm = new torch.fx.GraphModule(root, graph); + } catch { + const gm = new torch.fx.GraphModule(root, torch.fx.Graph()); + gm._graph = graph; + } + return gm; + } + }); this.registerType('torch.export.exported_program.ModuleCallEntry', class {}); this.registerType('torch.export.exported_program.ModuleCallSignature', class {}); this.registerFunction('torch.export.unflatten'); @@ -5996,17 +6036,21 @@ python.Execution = class { return new torch.fx.graph_module.GraphModule(root, graph); }); this.registerType('torch._export.serde.serialize.SerializedArtifact', class { - constructor(exported_program, state_dict, constants) { + constructor(exported_program, state_dict, constants, example_inputs) { this.exported_program = exported_program; this.state_dict = state_dict; this.constants = constants; + this.example_inputs = example_inputs; } }); + this.registerType('torch._export.serde.schema.SymInt', class { + }); this.registerFunction('torch._export.load', (f, expected_opset_version) => { const serialized_exported_program = f.get('serialized_exported_program.json'); const serialized_state_dict = f.get('serialized_state_dict.pt'); const serialized_constants = f.get('serialized_constants.pt'); - const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants); + const serialized_example_inputs = f.get('serialized_example_inputs.pt'); + const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants, serialized_example_inputs); return torch._export.serde.serialize.deserialize(artifact, expected_opset_version); }); this.registerFunction('torch._export.serde.serialize._dict_to_dataclass', (cls, data) => { @@ -6033,25 +6077,28 @@ python.Execution = class { return data; }); this.registerFunction('torch._export.serde.serialize.deserialize', (artifact, expected_opset_version) => { - artifact.exported_program = torch._export.serde.serialize._dict_to_dataclass(null, artifact.exported_program); - return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(artifact); + const serialized_exported_program = torch._export.serde.serialize._dict_to_dataclass(null, artifact.exported_program); + return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(serialized_exported_program, artifact.state_dict, artifact.constants, artifact.example_inputs); }); this.registerType('torch._export.serde.serialize.ExportedProgramDeserializer', class { constructor(expected_opset_version) { this.expected_opset_version = expected_opset_version; } - deserialize(serialized_artifact) { - const symbol_name_to_range = new Map(Object.entries(serialized_artifact.exported_program.range_constraints)); + deserialize(exported_program, state_dict, constants, example_inputs) { + const symbol_name_to_range = new Map(Object.entries(exported_program.range_constraints)); /* symbol_name_to_range = { k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)) - for k, v in serialized_artifact.exported_program.range_constraints.items() + for k, v in exported_program.range_constraints.items() } */ - const constants = serialized_artifact.constants ? torch.load(serialized_artifact.constants) : null; - const tensor_constants = constants ? new Map(Object.entries(constants).filter(([, tensor]) => tensor instanceof torch.Tensor)) : null; const deserializer = new torch._export.serde.serialize.GraphModuleDeserializer(); - const res = deserializer.deserialize(serialized_artifact.exported_program.graph_module, symbol_name_to_range, constants); + const res = deserializer.deserialize( + exported_program.graph_module, + state_dict, + constants, + example_inputs, + symbol_name_to_range); const range_constraints = null; /* range_constraints = self.deserialize_range_constraints( @@ -6061,16 +6108,21 @@ python.Execution = class { self._validate_model_opset_version(model_opset_version) upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version) */ - const state_dict = serialized_artifact.state_dict ? torch.load(serialized_artifact.state_dict) : null; - const exported_program = new torch.export.exported_program.ExportedProgram( + return new torch.export.exported_program.ExportedProgram( res.graph_module, res.graph_module.graph, res.signature, - state_dict, range_constraints, res.module_call_graph, null, + res.state_dict, range_constraints, res.module_call_graph, res.example_inputs, null, // verifier=load_verifier(serialized_artifact.exported_program.dialect), - tensor_constants); - return exported_program; + res.constants); // return upgrader.upgrade(exported_program) } }); + this.registerFunction('torch._export.serde.serialize.deserialize_torch_artifact', (serialized) => { + if (!serialized) { + return new builtins.dict(); + } + const artifact = torch.load(serialized); + return artifact; + }); this.registerType('torch._export.serde.serialize.GraphModuleDeserializer', class { constructor() { this.serialized_name_to_node = new Map(); @@ -6179,7 +6231,7 @@ python.Execution = class { Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata)); } } - deserialize(serialized_graph_module, symbol_name_to_range, constants) { + deserialize(serialized_graph_module, serialized_state_dict, constants, example_inputs, symbol_name_to_range) { this.shape_env = new torch.fx.experimental.symbolic_shapes.ShapeEnv(/* assume_static_by_default = True */); /* this.fake_tensor_mode = FakeTensorMode( @@ -6189,16 +6241,31 @@ python.Execution = class { ) */ this.symbol_name_to_symbol = new Map(); + this.constants = torch._export.serde.serialize.deserialize_torch_artifact(constants); + this.signature = null; // self.deserialize_signature(serialized_graph_module.signature) this.symbol_name_to_range = symbol_name_to_range || new Map(); - this.constants = constants || new Map(); + /* + if symbol_name_to_range: + for k, vr in symbol_name_to_range.items(): + lower = int(vr.lower) + if vr.upper >= 2: # max is >= 2, not sym bool range + lower = max(2, lower) + self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper) + */ + this.example_inputs = null; + if (example_inputs && example_inputs.length > 0) { + torch._export.serde.serialize.deserialize_torch_artifact(example_inputs); + } this.deserialize_graph(serialized_graph_module.graph); - const sig = null; // self.deserialize_signature(serialized_graph_module.signature) const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph) return { graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph), - signature: sig, + signature: this.signature, module_call_graph, - names_to_symbols: this.symbol_name_to_symbol + names_to_symbols: this.symbol_name_to_symbol, + state_dict: torch._export.serde.serialize.deserialize_torch_artifact(serialized_state_dict), + constants: this.constants, + example_inputs: this.example_inputs, }; } sync_fx_node(name, fx_node) { @@ -6380,7 +6447,7 @@ python.Execution = class { */ } const hint = s.as_expr.hint || null; - if (hint && hint.$type === 'as_int') { + if (hint && (hint.$type === 'as_int' || hint.as_int !== undefined)) { return this.deserialize_sym_int(hint); } return this.shape_env.create_symintnode(sym, hint); diff --git a/source/pytorch.js b/source/pytorch.js index 4c72685f62..aff89c4abe 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1132,10 +1132,12 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container { this.format = 'PyTorch Export'; const serialized_state_dict = await this._fetch('serialized_state_dict.pt') || await this._fetch('serialized_state_dict.json'); const serialized_constants = await this._fetch('serialized_constants.pt') || await this._fetch('serialized_constants.json'); + const serialized_example_inputs = await this._fetch('serialized_example_inputs.pt'); const f = new Map(); f.set('serialized_exported_program.json', this.serialized_exported_program); f.set('serialized_state_dict.pt', serialized_state_dict); f.set('serialized_constants.pt', serialized_constants); + f.set('serialized_example_inputs.pt', serialized_example_inputs); const execution = new pytorch.Execution(); for (const event of this._events) { execution.on(event[0], event[1]); @@ -1157,13 +1159,15 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container { } delete this.serialized_exported_program; delete this.context; - /* const exported_program = */ torch._export.load(f); + /* eslint-disable no-unused-vars */ + const exported_program = torch._export.load(f); + /* eslint-enable no-unused-vars */ throw new pytorch.Error(`'torch.export' not supported.`); } async _fetch(name) { try { - const context = await this._context.fetch(name); + const context = await this.context.fetch(name); if (context) { return context.peek('zip'); } @@ -3292,7 +3296,14 @@ pytorch.Utility = class { case 'torch.cuda': return obj.__class__.__name__.endsWith('Tensor') ? obj : null; case 'torch.nn.parameter': - return obj.__class__.__name__ === 'Parameter' ? obj.data : null; + if (obj.__class__.__name__ === 'Parameter') { + const data = obj.data; + if (typeof obj.__name__ === 'string') { + data.__name__ = obj.__name__; + } + return data; + } + return null; default: return null; }