Skip to content

Commit

Permalink
Update pytorch.js (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 12, 2024
1 parent 14f49e0 commit f3e981f
Show file tree
Hide file tree
Showing 4 changed files with 471 additions and 112 deletions.
195 changes: 176 additions & 19 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -5632,6 +5632,9 @@ python.Execution = class {
return NaN;
});
this.registerFunction('torch.eq', (left, right) => {
const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
left = value(left);
right = value(right);
if (typeof left === 'string' && typeof right === 'string') {
return left === right;
}
Expand Down Expand Up @@ -5677,6 +5680,9 @@ python.Execution = class {
return self.replace(regex, '');
});
this.registerFunction('torch.gt', (left, right) => {
const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
left = value(left);
right = value(right);
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
if (!isNaN(left) && !isNaN(right)) {
return left > right;
Expand Down Expand Up @@ -5998,6 +6004,9 @@ python.Execution = class {
throw new python.Error("Unsupported 'torch.remainder' expression type.");
});
this.registerFunction('torch.ne', (left, right) => {
const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
left = value(left);
right = value(right);
if (typeof left === 'boolean' && typeof right === 'boolean') {
return left !== right;
}
Expand Down Expand Up @@ -6174,7 +6183,7 @@ python.Execution = class {
});
this.registerType('torch.OptionalType', class extends torch.Type {
constructor(elem) {
super();
super('OptionalType');
this._elem = elem;
}
getElementType() {
Expand All @@ -6183,35 +6192,85 @@ python.Execution = class {
});
this.registerType('torch.ListType', class extends torch.Type {
constructor(elem, size) {
super();
super('ListType');
this._elem = elem;
this._size = size;
if (size) {
this._size = size;
}
}
getElementType() {
return this._elem;
}
});
this.registerType('torch.FutureType', class extends torch.Type {
constructor(elem, size) {
super();
super('FutureType');
this._elem = elem;
this._size = size;
}
getElementType() {
return this._elem;
}
});
this.registerType('torch.TupleType', class extends torch.Type {});
this.registerType('torch.TensorType', class extends torch.Type {});
this.registerType('torch.TupleType', class extends torch.Type {
constructor() {
super('TupleType');
}
});
this.registerType('torch.TensorType', class extends torch.Type {
constructor() {
super('TensorType');
}
});
this.registerType('torch.AnyType', class extends torch.Type {});
this.registerType('torch.NumberType', class extends torch.Type {});
this.registerType('torch.BoolType', class extends torch.Type {});
this.registerType('torch.IntType', class extends torch.Type {});
this.registerType('torch.SymIntType', class extends torch.Type {});
this.registerType('torch.FloatType', class extends torch.Type {});
this.registerType('torch.StringType', class extends torch.Type {});
this.registerType('torch.ComplexType', class extends torch.Type {});
this.registerType('torch.DictType', class extends torch.Type {});
this.registerType('torch.NumberType', class extends torch.Type {
constructor() {
super('NumberType');
}
});
this.registerType('torch.BoolType', class extends torch.Type {
constructor() {
super('BoolType');
}
});
this.registerType('torch.IntType', class extends torch.Type {
constructor() {
super('IntType');
}
});
this.registerType('torch.SymIntType', class extends torch.Type {
constructor() {
super('SymIntType');
}
});
this.registerType('torch.FloatType', class extends torch.Type {
constructor() {
super('FloatType');
}
});
this.registerType('torch.StringType', class extends torch.Type {
constructor() {
super('StringType');
}
});
this.registerType('torch.ComplexType', class extends torch.Type {
constructor() {
super('ComplexType');
}
});
this.registerType('torch.DictType', class extends torch.Type {
constructor(key, value) {
super('DictType');
this._key = key;
this._value = value;
}
getKeyType() {
return this._key;
}
getValueType() {
return this._value;
}
});
this.registerType('torch.DeviceObjType', class extends torch.Type {});
this.registerType('torch._C._GeneratorType', class extends torch.Type {});
this.registerType('torch.Argument', class {
Expand All @@ -6229,13 +6288,106 @@ python.Execution = class {
has_default_value() {
return this.default_value !== undefined;
}
toString() {
const list = [];
if (this.name) {
list.push(this.name);
}
return list.join('');
}
});
this.registerType('torch.FunctionSchema', class {
constructor(name, overload_name, args, returns) {
this.arguments = args;
this.returns = returns;
this.name = name;
this.overload_name = overload_name;
constructor(name, overload_name, args, returns, is_vararg, is_varret) {
let index = name.indexOf('(');
if (index === -1) {
this._name = name;
this._overload_name = overload_name;
this._arguments = args;
this._returns = returns;
this._is_vararg = is_vararg;
this._is_varret = is_varret;
} else {
const value = name.substring(0, index).trim();
this._arguments = name.substring(index + 1, name.length);
index = value.indexOf('.');
if (index === -1) {
this._name = value;
this._overload_name = '';
} else {
this._name = value.substring(0, index);
this._overload_name = value.substring(index + 1, value.length);
}
}
}
static parse(schema) {
return new torch.FunctionSchema(schema);
}
get name() {
return this._name;
}
get overload_name() {
return this._overload_name;
}
get arguments() {
this._parse();
return this._arguments;
}
get returns() {
this._parse();
return this._returns;
}
get is_vararg() {
this._parse();
return this._is_vararg;
}
get is_varret() {
this._parse();
return this._is_varret;
}
_parse() {
if (!Array.isArray(this._arguments)) {
throw new python.Error(`'torch.FunctionSchema.parse' not implemented.`);
}
}
toString() {
const list = [this.name];
const overload_name = this.overload_name;
if (overload_name !== '' && overload_name !== 'default') {
list.push(`.${this.overload_name}`);
}
list.push('(');
let first = true;
for (const argument of this.arguments) {
if (!first) {
list.push(', ');
}
first = false;
list.push(argument.toString());
}
if (this.is_vararg) {
if (!first) {
list.push(', ');
}
first = true;
list.push('...');
}
list.push(') -> ');
const returns = this.returns;
if (returns.length > 1) {
list.push('(');
}
first = true;
for (const argument of this.returns) {
if (!first) {
list.push(', ');
}
first = false;
list.push(argument.toString());
}
if (returns.length > 1) {
list.push(')');
}
return list.join('');
}
});
this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase {
Expand Down Expand Up @@ -7899,6 +8051,11 @@ python.Execution = class {
this.registerType('torch.cuda.DoubleTensor', class extends torch.Tensor {});
this.registerType('torch.cuda.amp.grad_scaler.GradScaler', class {});
this.registerFunction('torch.cuda.amp.grad_scaler._refresh_per_optimizer_state');
this.registerType('torch.SymBool', class {
constructor(node) {
this.node = node;
}
});
this.registerType('torch.SymInt', class {
constructor(node) {
this.node = node;
Expand Down
93 changes: 93 additions & 0 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,18 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::_cdist_forward",
"inputs": [
{ "name": "x1", "type": "Tensor" },
{ "name": "x2", "type": "Tensor" },
{ "name": "p", "type": "float32" },
{ "name": "compute_mode", "type": "int64", "optional": true }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "aten::_coalesce",
"inputs": [
Expand Down Expand Up @@ -1035,6 +1047,41 @@
{ "name": "?", "type": "Tensor" }
]
},
{
"name": "aten::_scaled_dot_product_efficient_attention",
"inputs": [
{ "name": "query", "type": "Tensor" },
{ "name": "key", "type": "Tensor" },
{ "name": "value", "type": "Tensor" },
{ "name": "attn_bias", "type": "Tensor", "optional": true },
{ "name": "compute_log_sumexp", "type": "boolean" },
{ "name": "dropout_p", "type": "float32", "default": 0.0 },
{ "name": "is_causal", "type": "boolean", "default": false },
{ "name": "scale", "type": "float32", "optional": true, "default": null, "kwarg_only": true }
],
"outputs": [
{ "name": "output", "type": "Tensor" },
{ "name": "log_sumexp", "type": "Tensor" },
{ "name": "philox_seed", "type": "Tensor" },
{ "name": "philox_offset", "type": "Tensor" }
]
},
{
"name": "aten::_scaled_dot_product_flash_attention_for_cpu",
"inputs": [
{ "name": "query", "type": "Tensor" },
{ "name": "key", "type": "Tensor" },
{ "name": "value", "type": "Tensor" },
{ "name": "dropout_p", "type": "float32", "default": 0.0 },
{ "name": "is_causal", "type": "boolean", "default": false },
{ "name": "attn_mask", "type": "Tensor", "optional": true, "default": null, "kwarg_only": true },
{ "name": "scale", "type": "float32", "optional": true, "default": null, "kwarg_only": true }
],
"outputs": [
{ "name": "output", "type": "Tensor" },
{ "name": "logsumexp", "type": "Tensor" }
]
},
{
"name": "aten::_shape_as_tensor",
"inputs": [
Expand Down Expand Up @@ -1153,6 +1200,16 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::_unsafe_view",
"inputs": [
{ "name": "self", "type": "Tensor" },
{ "name": "size", "type": "SymInt[]" }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "aten::_weight_norm",
"inputs": [
Expand Down Expand Up @@ -13868,6 +13925,42 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::tensor.bool",
"inputs": [
{ "name": "t", "type": "boolean" },
{ "name": "dtype", "type": "ScalarType", "optional": true, "default": null, "kwarg_only": true },
{ "name": "device", "type": "Device", "optional": true, "default": null, "kwarg_only": true },
{ "name": "requires_grad", "type": "boolean", "default": false, "kwarg_only": true }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "aten::tensor.float",
"inputs": [
{ "name": "t", "type": "float32" },
{ "name": "dtype", "type": "ScalarType", "optional": true, "default": null, "kwarg_only": true },
{ "name": "device", "type": "Device", "optional": true, "default": null, "kwarg_only": true },
{ "name": "requires_grad", "type": "boolean", "default": false, "kwarg_only": true }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "aten::tensor.int",
"inputs": [
{ "name": "t", "type": "int64" },
{ "name": "dtype", "type": "ScalarType", "optional": true, "default": null, "kwarg_only": true },
{ "name": "device", "type": "Device", "optional": true, "default": null, "kwarg_only": true },
{ "name": "requires_grad", "type": "boolean", "default": false, "kwarg_only": true }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "aten::tensor_split.indices",
"inputs": [
Expand Down
Loading

0 comments on commit f3e981f

Please sign in to comment.