Skip to content

Commit

Permalink
Update python.js (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 26, 2024
1 parent b68868b commit 89ad09b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 27 deletions.
115 changes: 91 additions & 24 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4865,13 +4865,15 @@ python.Execution = class {
return chars.join('');
}
});
torch.fx.Graph = torch.fx.graph.Graph;
this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {
constructor(root, graph) {
constructor(root, graph, class_name) {
super();
this.__class__.__name__ = class_name || 'GraphModule';
this.graph = graph;
}
});
torch.fx.Graph = torch.fx.graph.Graph;
torch.fx.GraphModule = torch.fx.graph_module.GraphModule;
this.registerType('torch.fx.immutable_collections.immutable_dict', class extends builtins.dict {});
this.registerFunction('torch.fx._symbolic_trace.wrap', (fn_or_name) => {
return fn_or_name;
Expand Down Expand Up @@ -5988,25 +5990,67 @@ python.Execution = class {
this.registerType('torch.export.graph_signature.OutputKind', class extends this._enum.Enum {});
this.registerType('torch.export.graph_signature.OutputSpec', class extends this._enum.Enum {});
this.registerType('torch.export.graph_signature.TensorArgument', class {});
this.registerType('torch.export.exported_program.ExportedProgram', class {});
this.registerType('torch.export.exported_program.ExportedProgram', class {
constructor(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, verifier, tensor_constants, constants) {
// graph._codegen = torch.fx.graph.CodeGen()
this._graph_module = this._create_graph_module_for_export(root, graph);
if (root instanceof torch.fx.GraphModule) {
// this._graph_module.meta.update(root.meta);
}
this._graph_signature = graph_signature;
this._state_dict = state_dict;
this._range_constraints = range_constraints;
this._module_call_graph = module_call_graph;
this._example_inputs = example_inputs;
this._constants = tensor_constants || constants || {};

/*
graph: torch.fx.Graph,
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
range_constraints: "Dict[sympy.Symbol, Any]",
module_call_graph: List[ModuleCallEntry],
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
constants: Optional[
Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
] = None,
*,
verifiers: Optional[List[Type[Verifier]]] = None,
*/
}
_create_graph_module_for_export(root, graph) {
let gm = null;
try {
gm = new torch.fx.GraphModule(root, graph);
} catch {
const gm = new torch.fx.GraphModule(root, torch.fx.Graph());
gm._graph = graph;
}
return gm;
}
});
this.registerType('torch.export.exported_program.ModuleCallEntry', class {});
this.registerType('torch.export.exported_program.ModuleCallSignature', class {});
this.registerFunction('torch.export.unflatten');
this.registerFunction('torch._export.exported_program._create_graph_module_for_export', (root, graph) => {
return new torch.fx.graph_module.GraphModule(root, graph);
});
this.registerType('torch._export.serde.serialize.SerializedArtifact', class {
constructor(exported_program, state_dict, constants) {
constructor(exported_program, state_dict, constants, example_inputs) {
this.exported_program = exported_program;
this.state_dict = state_dict;
this.constants = constants;
this.example_inputs = example_inputs;
}
});
this.registerType('torch._export.serde.schema.SymInt', class {
});
this.registerFunction('torch._export.load', (f, expected_opset_version) => {
const serialized_exported_program = f.get('serialized_exported_program.json');
const serialized_state_dict = f.get('serialized_state_dict.pt');
const serialized_constants = f.get('serialized_constants.pt');
const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants);
const serialized_example_inputs = f.get('serialized_example_inputs.pt');
const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants, serialized_example_inputs);
return torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
});
this.registerFunction('torch._export.serde.serialize._dict_to_dataclass', (cls, data) => {
Expand All @@ -6033,25 +6077,28 @@ python.Execution = class {
return data;
});
this.registerFunction('torch._export.serde.serialize.deserialize', (artifact, expected_opset_version) => {
artifact.exported_program = torch._export.serde.serialize._dict_to_dataclass(null, artifact.exported_program);
return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(artifact);
const serialized_exported_program = torch._export.serde.serialize._dict_to_dataclass(null, artifact.exported_program);
return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(serialized_exported_program, artifact.state_dict, artifact.constants, artifact.example_inputs);
});
this.registerType('torch._export.serde.serialize.ExportedProgramDeserializer', class {
constructor(expected_opset_version) {
this.expected_opset_version = expected_opset_version;
}
deserialize(serialized_artifact) {
const symbol_name_to_range = new Map(Object.entries(serialized_artifact.exported_program.range_constraints));
deserialize(exported_program, state_dict, constants, example_inputs) {
const symbol_name_to_range = new Map(Object.entries(exported_program.range_constraints));
/*
symbol_name_to_range = {
k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val))
for k, v in serialized_artifact.exported_program.range_constraints.items()
for k, v in exported_program.range_constraints.items()
}
*/
const constants = serialized_artifact.constants ? torch.load(serialized_artifact.constants) : null;
const tensor_constants = constants ? new Map(Object.entries(constants).filter(([, tensor]) => tensor instanceof torch.Tensor)) : null;
const deserializer = new torch._export.serde.serialize.GraphModuleDeserializer();
const res = deserializer.deserialize(serialized_artifact.exported_program.graph_module, symbol_name_to_range, constants);
const res = deserializer.deserialize(
exported_program.graph_module,
state_dict,
constants,
example_inputs,
symbol_name_to_range);
const range_constraints = null;
/*
range_constraints = self.deserialize_range_constraints(
Expand All @@ -6061,16 +6108,21 @@ python.Execution = class {
self._validate_model_opset_version(model_opset_version)
upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version)
*/
const state_dict = serialized_artifact.state_dict ? torch.load(serialized_artifact.state_dict) : null;
const exported_program = new torch.export.exported_program.ExportedProgram(
return new torch.export.exported_program.ExportedProgram(
res.graph_module, res.graph_module.graph, res.signature,
state_dict, range_constraints, res.module_call_graph, null,
res.state_dict, range_constraints, res.module_call_graph, res.example_inputs,
null, // verifier=load_verifier(serialized_artifact.exported_program.dialect),
tensor_constants);
return exported_program;
res.constants);
// return upgrader.upgrade(exported_program)
}
});
this.registerFunction('torch._export.serde.serialize.deserialize_torch_artifact', (serialized) => {
if (!serialized) {
return new builtins.dict();
}
const artifact = torch.load(serialized);
return artifact;
});
this.registerType('torch._export.serde.serialize.GraphModuleDeserializer', class {
constructor() {
this.serialized_name_to_node = new Map();
Expand Down Expand Up @@ -6179,7 +6231,7 @@ python.Execution = class {
Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata));
}
}
deserialize(serialized_graph_module, symbol_name_to_range, constants) {
deserialize(serialized_graph_module, serialized_state_dict, constants, example_inputs, symbol_name_to_range) {
this.shape_env = new torch.fx.experimental.symbolic_shapes.ShapeEnv(/* assume_static_by_default = True */);
/*
this.fake_tensor_mode = FakeTensorMode(
Expand All @@ -6189,16 +6241,31 @@ python.Execution = class {
)
*/
this.symbol_name_to_symbol = new Map();
this.constants = torch._export.serde.serialize.deserialize_torch_artifact(constants);
this.signature = null; // self.deserialize_signature(serialized_graph_module.signature)
this.symbol_name_to_range = symbol_name_to_range || new Map();
this.constants = constants || new Map();
/*
if symbol_name_to_range:
for k, vr in symbol_name_to_range.items():
lower = int(vr.lower)
if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
*/
this.example_inputs = null;
if (example_inputs && example_inputs.length > 0) {
torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
}
this.deserialize_graph(serialized_graph_module.graph);
const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
return {
graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
signature: sig,
signature: this.signature,
module_call_graph,
names_to_symbols: this.symbol_name_to_symbol
names_to_symbols: this.symbol_name_to_symbol,
state_dict: torch._export.serde.serialize.deserialize_torch_artifact(serialized_state_dict),
constants: this.constants,
example_inputs: this.example_inputs,
};
}
sync_fx_node(name, fx_node) {
Expand Down Expand Up @@ -6380,7 +6447,7 @@ python.Execution = class {
*/
}
const hint = s.as_expr.hint || null;
if (hint && hint.$type === 'as_int') {
if (hint && (hint.$type === 'as_int' || hint.as_int !== undefined)) {
return this.deserialize_sym_int(hint);
}
return this.shape_env.create_symintnode(sym, hint);
Expand Down
17 changes: 14 additions & 3 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1132,10 +1132,12 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
this.format = 'PyTorch Export';
const serialized_state_dict = await this._fetch('serialized_state_dict.pt') || await this._fetch('serialized_state_dict.json');
const serialized_constants = await this._fetch('serialized_constants.pt') || await this._fetch('serialized_constants.json');
const serialized_example_inputs = await this._fetch('serialized_example_inputs.pt');
const f = new Map();
f.set('serialized_exported_program.json', this.serialized_exported_program);
f.set('serialized_state_dict.pt', serialized_state_dict);
f.set('serialized_constants.pt', serialized_constants);
f.set('serialized_example_inputs.pt', serialized_example_inputs);
const execution = new pytorch.Execution();
for (const event of this._events) {
execution.on(event[0], event[1]);
Expand All @@ -1157,13 +1159,15 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
}
delete this.serialized_exported_program;
delete this.context;
/* const exported_program = */ torch._export.load(f);
/* eslint-disable no-unused-vars */
const exported_program = torch._export.load(f);
/* eslint-enable no-unused-vars */
throw new pytorch.Error(`'torch.export' not supported.`);
}

async _fetch(name) {
try {
const context = await this._context.fetch(name);
const context = await this.context.fetch(name);
if (context) {
return context.peek('zip');
}
Expand Down Expand Up @@ -3292,7 +3296,14 @@ pytorch.Utility = class {
case 'torch.cuda':
return obj.__class__.__name__.endsWith('Tensor') ? obj : null;
case 'torch.nn.parameter':
return obj.__class__.__name__ === 'Parameter' ? obj.data : null;
if (obj.__class__.__name__ === 'Parameter') {
const data = obj.data;
if (typeof obj.__name__ === 'string') {
data.__name__ = obj.__name__;
}
return data;
}
return null;
default:
return null;
}
Expand Down

0 comments on commit 89ad09b

Please sign in to comment.