From 0cb3c3c368073cd5c1d5dbdda0040b14db9e559d Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Fri, 11 Oct 2024 23:38:39 -0700 Subject: [PATCH] Update pytorch.js (#637) --- source/pytorch.js | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/source/pytorch.js b/source/pytorch.js index eb611d873a..f6db6b9d1d 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -2261,6 +2261,45 @@ pytorch.jit.Execution = class extends pytorch.Execution { return super.target(expression, context); } + type(expression) { + const torch = this.torch; + if (expression.type === '[]' && expression.target.type === 'id') { + switch (expression.target.value) { + case 'List': { + const elementType = this.type(expression.arguments.value[0]); + return new torch.ListType(elementType); + } + case 'Optional': { + const elementType = this.type(expression.arguments.value[0]); + return new torch.OptionalType(elementType); + } + case 'Tuple': { + const args = expression.arguments.value.map((expression) => this.type(expression)); + return new torch.TupleType(args); + } + case 'Dict': { + const key = this.type(expression.arguments.value[0]); + const value = this.type(expression.arguments.value[1]); + return new torch.DictType(key, value); + } + default: { + throw new pytorch.Error(`Unsupported type element expression '${expression.target.value}'.`); + } + } + } + if (expression.type === 'id') { + switch (expression.value) { + case 'Tensor': return new torch.TensorType(); + case 'int': return new torch.IntType(); + case 'str': return new torch.StringType(); + case 'float': return new torch.FloatType(); + case 'number': return new torch.NumberType(); + default: throw new pytorch.Error(`Unsupported type expression '${expression.value}'.`); + } + } + throw new pytorch.Error(`Unsupported type expression '${expression.type}'.`); + } + call(target, name, args, context) { if (this.trace) { const overload = this._overload(target, name, args, context); @@ -2650,7 +2689,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { case 'Scalar': return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) || - (obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType)); + (obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType || obj.type() instanceof torch.NumberType)); case 'boolean': return obj === true || obj === false || (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.BoolType')); case 'boolean[]':