Skip to content

Commit

Permalink
Update coreml.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 19, 2023
1 parent af1f23b commit 84053c1
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions source/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ coreml.Graph = class {
const value = (name, type, description, tensor) => {
if (!values.has(name)) {
values.set(name, new coreml.Value(name, type, description, tensor));
} else if ((type && !type.equals(values.get(name).type)) || description || tensor) {
} else if ((type && !type.equals(values.get(name).type)) ||
(tensor && tensor !== values.get(name).initializer) || description) {
throw new coreml.Error("Duplicate value '" + name + "'.");
}
return values.get(name);
Expand Down Expand Up @@ -1267,7 +1268,7 @@ coreml.TensorType = class {
}

equals(obj) {
return obj && this._dataType === obj.dataType && this._shape && this._shape.equals(obj.shape);
return obj && this.dataType === obj.dataType && this.shape && this.shape.equals(obj.shape);
}

toString() {
Expand All @@ -1278,17 +1279,17 @@ coreml.TensorType = class {
coreml.TensorShape = class {

constructor(dimensions) {
this._dimensions = dimensions;
this._dimensions = dimensions.map((dim) => typeof dim === 'string' || Number.isInteger(dim) ? dim : dim.toNumber());
}

get dimensions() {
return this._dimensions;
}

equals(obj) {
return obj && Array.isArray(obj.dimensions) &&
Array.isArray(this._dimensions) && this._dimensions.length === obj.dimensions.length
&& obj.dimensions.every((value, index) => this._dimensions[index] === value);
return obj && Array.isArray(obj.dimensions) && Array.isArray(this._dimensions) &&
this._dimensions.length === obj.dimensions.length &&
obj.dimensions.every((value, index) => this._dimensions[index] === value);
}

toString() {
Expand Down Expand Up @@ -1425,7 +1426,7 @@ coreml.Utility = class {
case 'multiArrayType': {
let shape = new coreml.TensorShape([]);
if (type.multiArrayType.shape && type.multiArrayType.shape.length > 0) {
shape = new coreml.TensorShape(type.multiArrayType.shape);
shape = new coreml.TensorShape(type.multiArrayType.shape.map((dim) => dim.toNumber()));
}
let dataType;
const ArrayDataType = coreml.proto.ArrayFeatureType.ArrayDataType;
Expand Down

0 comments on commit 84053c1

Please sign in to comment.