Skip to content

Commit

Permalink
Update onnx.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 22, 2023
1 parent a5b4652 commit 4fda491
Showing 1 changed file with 201 additions and 112 deletions.
313 changes: 201 additions & 112 deletions source/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,57 @@ onnx.ModelFactory = class {
async open(context, target) {
const open = async (model, format) => {
const metadata = await onnx.Metadata.open(context);
return new onnx.Model(metadata, model, format);
const graphs = new Set();
const queue = [ model.graph ];
const locations = new Set();
const tensor = (value) => {
if ((onnx.proto && value instanceof onnx.proto.SparseTensorProto) ||
(onnx.schema && value instanceof onnx.schema.SparseTensor)) {
tensor(value.indices);
tensor(value.indices);
} else if (value.data_location === onnx.DataLocation.EXTERNAL && Array.isArray(value.external_data)) {
for (const entry of value.external_data) {
if (entry.key === 'location') {
locations.add(entry.value);
}
}
}
};
while (queue.length > 0) {
const graph = queue.shift();
for (const initializer of graph.initializer) {
tensor(initializer);
}
for (const sparse_initializer of graph.sparse_initializer) {
tensor(sparse_initializer);
}
if (Array.isArray(graph.node)) {
for (const node of graph.node) {
if (Array.isArray(node.attribute)) {
for (const attribute of node.attribute) {
if (attribute.g) {
queue.push(attribute.g);
} else if (Array.isArray(attribute.graphs) && attribute.graphs.length > 0) {
queue.push(...attribute.graphs);
} else if (attribute.t) {
tensor(attribute.t);
} else if (Array.isArray(attribute.tensors) && attribute.tensors.length > 0) {
attribute.tensors.every((value) => tensor(value));
} else if (attribute.sparse_tensor) {
tensor(attribute.sparse_tensor);
} else if (Array.isArray(attribute.sparse_tensors) && attribute.sparse_tensors.length > 0) {
attribute.sparse_tensors.every((value) => tensor(value));
}
}
}
}
}
graphs.add(graph);
}
const keys = Array.from(locations);
const streams = await Promise.all(keys.map((location) => context.request(location, null)));
const weights = new Map(keys.map((key, index) => [ key, streams[index] ]));
return new onnx.Model(metadata, format, model, Array.from(graphs), weights);
};
switch (target) {
case 'onnx.pbtxt.ModelProto':
Expand Down Expand Up @@ -276,7 +326,7 @@ onnx.ModelFactory = class {

onnx.Model = class {

constructor(metadata, model, format) {
constructor(metadata, format, model, graphs, locations) {
this._graphs = [];
this._format = format;
this._producer = model.producer_name && model.producer_name.length > 0 ? model.producer_name + (model.producer_version && model.producer_version.length > 0 ? ' ' + model.producer_version : '') : null;
Expand All @@ -286,7 +336,6 @@ onnx.Model = class {
this._description = model.doc_string;
this._metadata = [];
this._imports = null;

const imports = new Map();
if (model.opset_import && model.opset_import.length > 0) {
for (const opset_import of model.opset_import) {
Expand All @@ -302,7 +351,6 @@ onnx.Model = class {
imports.set('ai.onnx', 1);
imports.set('ai.onnx.ml', 1);
}

let imageFormat = '';
const metadata_props = model.metadata_props;
if (metadata_props) {
Expand Down Expand Up @@ -347,27 +395,14 @@ onnx.Model = class {
}
imageFormat = [ imageMetadata['Image.BitmapPixelFormat'], imageMetadata['Image.ColorSpaceGamma'], imageMetadata['Image.NominalPixelRange'] ].filter((item) => item);
}
metadata = new onnx.GraphMetadata(metadata, imports);
const context = new onnx.ModelContext(metadata, locations, imageFormat);
for (const func of model.functions || []) {
context.metadata.add(new onnx.Function(context, func));
}
this._graphs = [];
if (model && model.graph) {
const graphMetadata = new onnx.GraphMetadata(metadata, imports);
const context = new onnx.ModelContext(graphMetadata, imageFormat);
for (const func of model.functions || []) {
context.metadata.add(new onnx.Function(context, func));
}
const graphs = [ model.graph ];
while (graphs.length > 0) {
const graph = graphs.shift();
this._graphs.push(context.graph(graph));
for (const node of graph.node || []) {
for (const attribute of node.attribute || []) {
if (attribute.g) {
graphs.push(attribute.g);
} else if (attribute.graphs && attribute.graphs.length > 0) {
graphs.push(...attribute.graphs);
}
}
}
}
for (const graph of graphs) {
this._graphs.push(context.graph(graph));
}
}

Expand Down Expand Up @@ -788,104 +823,127 @@ onnx.Tensor = class {
(onnx.schema && tensor instanceof onnx.schema.SparseTensor)) {
this._name = tensor.values.name || '';
this._type = context.createTensorType(tensor.values.data_type, tensor.dims.map((dim) => dim), null);
this._location = Array.from(new Set([ context.createLocation(tensor.values.data_location), context.createLocation(tensor.indices.data_location) ])).join(':');
this._location = context.createLocation(tensor.values.data_location);
this._layout = 'sparse';
this._values = new onnx.Tensor(context, tensor.values);
this._indices = new onnx.Tensor(context, tensor.indices);
} else {
this._name = tensor.name || '';
this._type = context.createTensorType(tensor.data_type, tensor.dims.map((dim) => dim), null);
this._location = context.createLocation(tensor.data_location);
if (tensor.data_location === onnx.DataLocation.DEFAULT) {
switch (tensor.data_type) {
case onnx.DataType.UNDEFINED: {
break;
}
case onnx.DataType.FLOAT:
this._data = new Float32Array(tensor.float_data);
this._layout = '|';
break;
case onnx.DataType.DOUBLE:
this._data = new Float64Array(tensor.double_data);
this._layout = '|';
break;
case onnx.DataType.BOOL:
if (tensor.int32_data && tensor.int32_data.length > 0) {
const array = tensor.int32_data;
this._data = new Array(array.length);
for (let i = 0; i < this._data.length; i++) {
this._data[i] = array[i] === 0 ? false : true;
switch (tensor.data_location) {
case onnx.DataLocation.DEFAULT: {
switch (tensor.data_type) {
case onnx.DataType.UNDEFINED: {
break;
}
case onnx.DataType.FLOAT:
this._data = new Float32Array(tensor.float_data);
this._layout = '|';
break;
case onnx.DataType.DOUBLE:
this._data = new Float64Array(tensor.double_data);
this._layout = '|';
break;
case onnx.DataType.BOOL:
if (tensor.int32_data && tensor.int32_data.length > 0) {
const array = tensor.int32_data;
this._data = new Array(array.length);
for (let i = 0; i < this._data.length; i++) {
this._data[i] = array[i] === 0 ? false : true;
}
this._layout = '|';
}
break;
case onnx.DataType.INT8:
this._data = new Int8Array(tensor.int32_data);
this._layout = '|';
}
break;
case onnx.DataType.INT8:
this._data = new Int8Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.UINT8:
this._data = new Uint8Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.INT16:
this._data = new Int32Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.UINT16:
this._data = new Int32Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.INT32:
this._data = new Int32Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.UINT32:
case onnx.DataType.UINT64:
this._data = tensor.uint64_data;
this._layout = '|';
break;
case onnx.DataType.INT64:
this._data = tensor.int64_data;
this._layout = '|';
break;
case onnx.DataType.STRING:
this._data = tensor.string_data;
this._layout = '|';
break;
case onnx.DataType.COMPLEX64:
case onnx.DataType.COMPLEX128:
break;
case onnx.DataType.FLOAT16:
case onnx.DataType.BFLOAT16:
if (tensor.int32_data && tensor.int32_data.length > 0) {
const array = tensor.int32_data;
const buffer = new Uint8Array(array.length << 1);
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
for (let i = 0; i < array.length; i++) {
view.setUint16(i << 1, array[i], true);
break;
case onnx.DataType.UINT8:
this._data = new Uint8Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.INT16:
this._data = new Int32Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.UINT16:
this._data = new Int32Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.INT32:
this._data = new Int32Array(tensor.int32_data);
this._layout = '|';
break;
case onnx.DataType.UINT32:
case onnx.DataType.UINT64:
this._data = tensor.uint64_data;
this._layout = '|';
break;
case onnx.DataType.INT64:
this._data = tensor.int64_data;
this._layout = '|';
break;
case onnx.DataType.STRING:
this._data = tensor.string_data;
this._layout = '|';
break;
case onnx.DataType.COMPLEX64:
case onnx.DataType.COMPLEX128:
break;
case onnx.DataType.FLOAT16:
case onnx.DataType.BFLOAT16:
if (tensor.int32_data && tensor.int32_data.length > 0) {
const array = tensor.int32_data;
const buffer = new Uint8Array(array.length << 1);
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
for (let i = 0; i < array.length; i++) {
view.setUint16(i << 1, array[i], true);
}
this._data = buffer;
this._layout = '<';
}
break;
case onnx.DataType.FLOAT8E4M3FN:
case onnx.DataType.FLOAT8E4M3FNUZ:
case onnx.DataType.FLOAT8E5M2:
case onnx.DataType.FLOAT8E5M2FNUZ:
if (tensor.int32_data && tensor.int32_data.length > 0) {
this._data = new Uint8Array(Array.from(tensor.int32_data));
this._layout = '<';
}
this._data = buffer;
this._layout = '<';
break;
default:
throw new onnx.Error("Unsupported tensor data type '" + tensor.data_type + "'.");
}
if (this._data && (Array.isArray(this._data) || ArrayBuffer.isView(this._data)) && this._data.length === 0) {
this._data = undefined;
}
if (!this._data && tensor.raw_data && tensor.raw_data.length > 0) {
this._data = tensor.raw_data;
this._layout = '<';
}
break;
}
case onnx.DataLocation.EXTERNAL: {
if (Array.isArray(tensor.external_data)) {
const external_data = {};
for (const entry of tensor.external_data) {
external_data[entry.key] = entry.value;
}
break;
case onnx.DataType.FLOAT8E4M3FN:
case onnx.DataType.FLOAT8E4M3FNUZ:
case onnx.DataType.FLOAT8E5M2:
case onnx.DataType.FLOAT8E5M2FNUZ:
if (tensor.int32_data && tensor.int32_data.length > 0) {
this._data = new Uint8Array(Array.from(tensor.int32_data));
this._layout = '<';
if (external_data.location && external_data.offset && external_data.length) {
const offset = parseInt(external_data.offset, 10);
const length = parseInt(external_data.length, 10);
if (Number.isInteger(offset) && Number.isInteger(length)) {
this._data = context.location(external_data.location, offset, length);
this._layout = '<';
}
}
break;
default:
throw new onnx.Error("Unsupported tensor data type '" + tensor.data_type + "'.");
}
if (this._data && (Array.isArray(this._data) || ArrayBuffer.isView(this._data)) && this._data.length === 0) {
this._data = undefined;
}
break;
}
if (!this._data && tensor.raw_data && tensor.raw_data.length > 0) {
this._data = tensor.raw_data;
this._layout = '<';
default: {
throw new Error();
}
}
}
Expand All @@ -912,9 +970,21 @@ onnx.Tensor = class {
}

get values() {
return this._layout === 'sparse' ? this._values : this._data;
switch (this._layout) {
case 'sparse': {
return this._values;
}
default: {
if (!this._data || this._data instanceof Uint8Array) {
return this._data;
}
if (Array.isArray(this._data) || ArrayBuffer.isView(this._data)) {
return this._data;
}
return this._data.peek();
}
}
}

};

onnx.TensorType = class {
Expand Down Expand Up @@ -1289,8 +1359,9 @@ onnx.AttributeType = {

onnx.ModelContext = class {

constructor(metadata, imageFormat) {
constructor(metadata, locations, imageFormat) {
this._metadata = metadata;
this._locations = locations;
this._imageFormat = imageFormat;
this._graphs = new Map();
}
Expand All @@ -1303,6 +1374,20 @@ onnx.ModelContext = class {
return this._imageFormat;
}

location(name, offset, length) {
if (this._locations.has(name)) {
const stream = this._locations.get(name);
if (offset < stream.length && (offset + length) < stream.length) {
const position = stream.position;
stream.seek(offset);
const value = stream.stream(length);
stream.seek(position);
return value;
}
}
return this._locations;
}

graph(value) {
if (!this._graphs.has(value)) {
this._graphs.set(value, new onnx.Graph(this, value));
Expand Down Expand Up @@ -1374,6 +1459,10 @@ onnx.GraphContext = class {
return this._tensors.get(name);
}

location(name, offset, length) {
return this._context.location(name, offset, length);
}

group(name) {
if (!this._groups.has(name)) {
const path = name.split('/');
Expand Down

0 comments on commit 4fda491

Please sign in to comment.