Skip to content

Commit

Permalink
Update pytorch.js (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 12, 2024
1 parent 9d225e8 commit 0cb3c3c
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,45 @@ pytorch.jit.Execution = class extends pytorch.Execution {
return super.target(expression, context);
}

type(expression) {
const torch = this.torch;
if (expression.type === '[]' && expression.target.type === 'id') {
switch (expression.target.value) {
case 'List': {
const elementType = this.type(expression.arguments.value[0]);
return new torch.ListType(elementType);
}
case 'Optional': {
const elementType = this.type(expression.arguments.value[0]);
return new torch.OptionalType(elementType);
}
case 'Tuple': {
const args = expression.arguments.value.map((expression) => this.type(expression));
return new torch.TupleType(args);
}
case 'Dict': {
const key = this.type(expression.arguments.value[0]);
const value = this.type(expression.arguments.value[1]);
return new torch.DictType(key, value);
}
default: {
throw new pytorch.Error(`Unsupported type element expression '${expression.target.value}'.`);
}
}
}
if (expression.type === 'id') {
switch (expression.value) {
case 'Tensor': return new torch.TensorType();
case 'int': return new torch.IntType();
case 'str': return new torch.StringType();
case 'float': return new torch.FloatType();
case 'number': return new torch.NumberType();
default: throw new pytorch.Error(`Unsupported type expression '${expression.value}'.`);
}
}
throw new pytorch.Error(`Unsupported type expression '${expression.type}'.`);
}

call(target, name, args, context) {
if (this.trace) {
const overload = this._overload(target, name, args, context);
Expand Down Expand Up @@ -2650,7 +2689,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
case 'Scalar':
return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
(pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
(obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType));
(obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType || obj.type() instanceof torch.NumberType));
case 'boolean':
return obj === true || obj === false || (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.BoolType'));
case 'boolean[]':
Expand Down

0 comments on commit 0cb3c3c

Please sign in to comment.