From 21f9188ed57c6b62714e87da066d8386db689e79 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 23 Jul 2023 02:06:32 -0700 Subject: [PATCH] Update onnx.js --- source/onnx.js | 93 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/source/onnx.js b/source/onnx.js index aa69b136fc..6b70d65470 100644 --- a/source/onnx.js +++ b/source/onnx.js @@ -164,29 +164,33 @@ onnx.ModelFactory = class { async open(context, target) { const open = async (model, format) => { const metadata = await onnx.Metadata.open(context); - const graphs = new Set(); - const queue = [ model.graph ]; const locations = new Set(); - const tensor = (value) => { - if ((onnx.proto && value instanceof onnx.proto.SparseTensorProto) || - (onnx.schema && value instanceof onnx.schema.SparseTensor)) { - tensor(value.indices); - tensor(value.indices); - } else if (value.data_location === onnx.DataLocation.EXTERNAL && Array.isArray(value.external_data)) { - for (const entry of value.external_data) { + const location = (tensor) => { + if ((onnx.proto && tensor instanceof onnx.proto.SparseTensorProto) || + (onnx.schema && tensor instanceof onnx.schema.SparseTensor)) { + location(tensor.indices); + location(tensor.indices); + } else if (tensor.data_location === onnx.DataLocation.EXTERNAL && Array.isArray(tensor.external_data)) { + for (const entry of tensor.external_data) { if (entry.key === 'location') { locations.add(entry.value); } } } }; + const graphs = new Set(); + const queue = [ model.graph ]; while (queue.length > 0) { const graph = queue.shift(); - for (const initializer of graph.initializer) { - tensor(initializer); + if (Array.isArray(graph.initializer)) { + for (const initializer of graph.initializer) { + location(initializer); + } } - for (const sparse_initializer of graph.sparse_initializer) { - tensor(sparse_initializer); + if (Array.isArray(graph.sparse_initializer)) { + for (const sparse_initializer of graph.sparse_initializer) { + location(sparse_initializer); + } } if (Array.isArray(graph.node)) { for (const node of graph.node) { @@ -194,16 +198,22 @@ onnx.ModelFactory = class { for (const attribute of node.attribute) { if (attribute.g) { queue.push(attribute.g); - } else if (Array.isArray(attribute.graphs) && attribute.graphs.length > 0) { - queue.push(...attribute.graphs); } else if (attribute.t) { - tensor(attribute.t); - } else if (Array.isArray(attribute.tensors) && attribute.tensors.length > 0) { - attribute.tensors.every((value) => tensor(value)); + location(attribute.t); } else if (attribute.sparse_tensor) { - tensor(attribute.sparse_tensor); + location(attribute.sparse_tensor); + } else if (Array.isArray(attribute.graphs) && attribute.graphs.length > 0) { + for (const graph of attribute.graphs) { + queue.push(graph); + } + } else if (Array.isArray(attribute.tensors) && attribute.tensors.length > 0) { + for (const tensor of attribute.tensors) { + location(tensor); + } } else if (Array.isArray(attribute.sparse_tensors) && attribute.sparse_tensors.length > 0) { - attribute.sparse_tensors.every((value) => tensor(value)); + for (const tensor of attribute.sparse_tensors) { + location(tensor); + } } } } @@ -211,9 +221,16 @@ onnx.ModelFactory = class { } graphs.add(graph); } - const keys = Array.from(locations); - const streams = await Promise.all(keys.map((location) => context.request(location, null))); - const weights = new Map(keys.map((key, index) => [ key, streams[index] ])); + const weights = new Map(); + try { + const keys = Array.from(locations); + const streams = await Promise.all(keys.map((location) => context.request(location, null))); + for (let i = 0; i < keys.length; i++) { + weights.set(keys[i], streams[i]); + } + } catch (error) { + // continue regardless of error + } return new onnx.Model(metadata, format, model, Array.from(graphs), weights); }; switch (target) { @@ -240,17 +257,17 @@ onnx.ModelFactory = class { const reader = protobuf.BinaryReader.open(stream); const tensor = onnx.proto.TensorProto.decode(reader); tensor.name = tensor.name || context.identifier; + const attribute = new onnx.proto.AttributeProto(); + attribute.name = 'value'; + attribute.type = onnx.AttributeType.TENSOR; + attribute.t = tensor; + const node = new onnx.proto.NodeProto(); + node.op_type = 'Constant'; + node.attribute = [ attribute ]; + const graph = new onnx.proto.GraphProto(); + graph.node = [ node ]; const model = new onnx.proto.ModelProto(); - model.graph = new onnx.proto.GraphProto(); - model.graph.initializer = [ tensor ]; - model.graph.value_info = [ new onnx.proto.ValueInfoProto() ]; - model.graph.value_info[0].name = tensor.name; - model.graph.node = [ new onnx.proto.NodeProto() ]; - model.graph.node[0].op_type = 'Constant'; - model.graph.node[0].attribute = [ new onnx.proto.AttributeProto() ]; - model.graph.node[0].attribute[0].name = 'value'; - model.graph.node[0].attribute[0].type = onnx.AttributeType.TENSOR; - model.graph.node[0].attribute[0].t = tensor; + model.graph = graph; const format = 'ONNX Tensor'; return open(model, format); } catch (error) { @@ -1996,11 +2013,13 @@ onnx.Reader.text = class { this._expect('='); if (this._match('[')) { const list = []; - do { - list.push(this._literal()); + if (!this._match(']')) { + do { + list.push(this._literal()); + } + while (this._match(',')); + this._expect(']'); } - while (this._match(',')); - this._expect(']'); if (list.every((value) => typeof value === 'string')) { attribute.type = onnx.AttributeType.STRINGS; attribute.strings = list;