Skip to content

Commit

Permalink
Update onnx.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 23, 2023
1 parent 3153e0e commit 21f9188
Showing 1 changed file with 56 additions and 37 deletions.
93 changes: 56 additions & 37 deletions source/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -164,56 +164,73 @@ onnx.ModelFactory = class {
async open(context, target) {
const open = async (model, format) => {
const metadata = await onnx.Metadata.open(context);
const graphs = new Set();
const queue = [ model.graph ];
const locations = new Set();
const tensor = (value) => {
if ((onnx.proto && value instanceof onnx.proto.SparseTensorProto) ||
(onnx.schema && value instanceof onnx.schema.SparseTensor)) {
tensor(value.indices);
tensor(value.indices);
} else if (value.data_location === onnx.DataLocation.EXTERNAL && Array.isArray(value.external_data)) {
for (const entry of value.external_data) {
const location = (tensor) => {
if ((onnx.proto && tensor instanceof onnx.proto.SparseTensorProto) ||
(onnx.schema && tensor instanceof onnx.schema.SparseTensor)) {
location(tensor.indices);
location(tensor.indices);
} else if (tensor.data_location === onnx.DataLocation.EXTERNAL && Array.isArray(tensor.external_data)) {
for (const entry of tensor.external_data) {
if (entry.key === 'location') {
locations.add(entry.value);
}
}
}
};
const graphs = new Set();
const queue = [ model.graph ];
while (queue.length > 0) {
const graph = queue.shift();
for (const initializer of graph.initializer) {
tensor(initializer);
if (Array.isArray(graph.initializer)) {
for (const initializer of graph.initializer) {
location(initializer);
}
}
for (const sparse_initializer of graph.sparse_initializer) {
tensor(sparse_initializer);
if (Array.isArray(graph.sparse_initializer)) {
for (const sparse_initializer of graph.sparse_initializer) {
location(sparse_initializer);
}
}
if (Array.isArray(graph.node)) {
for (const node of graph.node) {
if (Array.isArray(node.attribute)) {
for (const attribute of node.attribute) {
if (attribute.g) {
queue.push(attribute.g);
} else if (Array.isArray(attribute.graphs) && attribute.graphs.length > 0) {
queue.push(...attribute.graphs);
} else if (attribute.t) {
tensor(attribute.t);
} else if (Array.isArray(attribute.tensors) && attribute.tensors.length > 0) {
attribute.tensors.every((value) => tensor(value));
location(attribute.t);
} else if (attribute.sparse_tensor) {
tensor(attribute.sparse_tensor);
location(attribute.sparse_tensor);
} else if (Array.isArray(attribute.graphs) && attribute.graphs.length > 0) {
for (const graph of attribute.graphs) {
queue.push(graph);
}
} else if (Array.isArray(attribute.tensors) && attribute.tensors.length > 0) {
for (const tensor of attribute.tensors) {
location(tensor);
}
} else if (Array.isArray(attribute.sparse_tensors) && attribute.sparse_tensors.length > 0) {
attribute.sparse_tensors.every((value) => tensor(value));
for (const tensor of attribute.sparse_tensors) {
location(tensor);
}
}
}
}
}
}
graphs.add(graph);
}
const keys = Array.from(locations);
const streams = await Promise.all(keys.map((location) => context.request(location, null)));
const weights = new Map(keys.map((key, index) => [ key, streams[index] ]));
const weights = new Map();
try {
const keys = Array.from(locations);
const streams = await Promise.all(keys.map((location) => context.request(location, null)));
for (let i = 0; i < keys.length; i++) {
weights.set(keys[i], streams[i]);
}
} catch (error) {
// continue regardless of error
}
return new onnx.Model(metadata, format, model, Array.from(graphs), weights);
};
switch (target) {
Expand All @@ -240,17 +257,17 @@ onnx.ModelFactory = class {
const reader = protobuf.BinaryReader.open(stream);
const tensor = onnx.proto.TensorProto.decode(reader);
tensor.name = tensor.name || context.identifier;
const attribute = new onnx.proto.AttributeProto();
attribute.name = 'value';
attribute.type = onnx.AttributeType.TENSOR;
attribute.t = tensor;
const node = new onnx.proto.NodeProto();
node.op_type = 'Constant';
node.attribute = [ attribute ];
const graph = new onnx.proto.GraphProto();
graph.node = [ node ];
const model = new onnx.proto.ModelProto();
model.graph = new onnx.proto.GraphProto();
model.graph.initializer = [ tensor ];
model.graph.value_info = [ new onnx.proto.ValueInfoProto() ];
model.graph.value_info[0].name = tensor.name;
model.graph.node = [ new onnx.proto.NodeProto() ];
model.graph.node[0].op_type = 'Constant';
model.graph.node[0].attribute = [ new onnx.proto.AttributeProto() ];
model.graph.node[0].attribute[0].name = 'value';
model.graph.node[0].attribute[0].type = onnx.AttributeType.TENSOR;
model.graph.node[0].attribute[0].t = tensor;
model.graph = graph;
const format = 'ONNX Tensor';
return open(model, format);
} catch (error) {
Expand Down Expand Up @@ -1996,11 +2013,13 @@ onnx.Reader.text = class {
this._expect('=');
if (this._match('[')) {
const list = [];
do {
list.push(this._literal());
if (!this._match(']')) {
do {
list.push(this._literal());
}
while (this._match(','));
this._expect(']');
}
while (this._match(','));
this._expect(']');
if (list.every((value) => typeof value === 'string')) {
attribute.type = onnx.AttributeType.STRINGS;
attribute.strings = list;
Expand Down

0 comments on commit 21f9188

Please sign in to comment.