Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 3, 2024
1 parent fb14c2e commit 25441ab
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 46 deletions.
216 changes: 171 additions & 45 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6349,20 +6349,39 @@ python.Execution = class {
}
});
this.registerType('torch.TupleType', class extends torch.Type {
constructor(elements) {
super('TupleType');
constructor(elements, annotation_str, schema) {
super('TupleType', annotation_str);
this._elements = elements;
this._schema = schema;
}
static get(elements) {
return new torch.TupleType(elements);
}
static createNamed(qualified_name, field_names, field_types /*, field_defaults */) {
const args = [];
for (let i = 0; i < field_names.length; i++) {
const arg = new torch.Argument(field_names[i], field_types[i], field_types[i]);
args.push(arg);
}
const schema = new torch.FunctionSchema(qualified_name, args);
return new torch.TupleType(field_types, qualified_name, schema);
}
elements() {
return this._elements;
}
schema() {
return this._schema;
}
str() {
if (this._schema) {
return `NamedTuple(...)`;
}
return `(${this.elements().map((elem) => elem.str()).join(', ')})`;
}
__str__() {
if (this.annotation_str) {
return this.annotation_str;
}
return `Tuple[${this.elements().map((elem) => elem.__str__()).join(', ')}]`;
}
});
Expand Down Expand Up @@ -7074,11 +7093,11 @@ python.Execution = class {
const index = name.indexOf('(');
if (index === -1) {
this._name = name;
this._overload_name = overload_name;
this._arguments = args;
this._returns = returns;
this._is_vararg = is_vararg;
this._is_varret = is_varret;
this._overload_name = overload_name || '';
this._arguments = args || [];
this._returns = returns || [];
this._is_vararg = is_vararg || false;
this._is_varret = is_varret || false;
} else {
const value = name.substring(0, index).trim();
const dot = value.indexOf('.');
Expand Down Expand Up @@ -7689,22 +7708,32 @@ python.Execution = class {
this.register('torch.jit._script');
this.register('torch.jit._trace');
this.registerType('torch.jit.Source', class {
constructor(text) {
this._text = text;
constructor(text_view, filename) {
this._text_view = text_view;
this._filename = filename;
}
text_str() {
return this._text_view;
}
filename() {
return this._filename;
}
});
this.registerType('torch.jit.SourceLoader', class {
constructor(reader, code_prefix) {
this._reader = reader;
this._code_prefix = code_prefix;
this.registerType('torch.jit.QualifiedName', class {
constructor(name) {
const index = name.lastIndexOf('.');
this._qualifiedName = name;
this._prefix = index === -1 ? '' : name.substring(0, index);
this._name = index === -1 ? name : name.substring(index + 1);
}
loadSource(qualifier) {
const path = `${this._code_prefix}/${qualifier}.py`;
if (this._reader.has_record(path)) {
const data = this._reader.get_record(path);
return new torch.jit.Source(data);
}
return null;
qualifiedName() {
return this._qualifiedName; // "foo.bar.baz"
}
prefix() {
return this._prefix; // "foo.bar"
}
name() {
return this._name; // "baz"
}
});
this.registerType('torch.jit.SourceImporter', class {
Expand All @@ -7713,17 +7742,93 @@ python.Execution = class {
this._constant_table = constant_table;
this._source_loader = source_loader;
this._version = version;
this._loaded_sources = new Set();
this._to_be_defined = new Map();
}
loadType(/* name */) {
//
}
resolveType(name) {
return this.findNamedType(new torch.jit.QualifiedName(name));
name = new torch.jit.QualifiedName(name);
return this.findNamedType(name);
}
findNamedType(name) {
// if (auto custom_class = getCustomClass(name.qualifiedName())) {
// return custom_class;
// }
this.parseSourceIfNeeded(name.prefix());
const key = name.qualifiedName();
const it = this._to_be_defined.get(name.qualifiedName());
if (it && it.type === 'class') {
this._to_be_defined.delete(key);
this.importNamedType(name.prefix(), it);
}
return this._cu.get_type(name);
}
importNamedType(qualifier, class_def) {
const qualified_name = new torch.jit.QualifiedName(`${qualifier}.${class_def.name}`);
if (class_def.bases.length === 0) {
return this.importClass(qualified_name, class_def, false);
}
const superclass_name = class_def.bases[0].value;
if (superclass_name === 'Module') {
return this.importClass(qualified_name, class_def, true);
} else if (superclass_name === 'NamedTuple') {
return this.importNamedTuple(qualified_name, class_def);
} else if (superclass_name === 'Interface') {
// cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=false);
return null;
} else if (superclass_name === 'ModuleInterface') {
// cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=true);
return null;
} else if (superclass_name === 'Enum') {
// importEnum(qualified_name, class_def);
return null;
}
throw new python.Error('TorchScript does not support class inheritance.');
}
importClass(/* qualified_name, class_def, is_module */) {
return null;
}
importNamedTuple(qualified_name, named_tuple_def) {
const field_names = [];
const field_types = [];
const field_defaults = [];
for (const statement of named_tuple_def.body.statements) {
if (statement.type !== 'var') {
throw new python.Error('Unexpected statement in NamedTuple body.');
}
field_names.push(statement.name);
field_types.push(this._cu.execution.type(statement.variableType));
}
const tt = torch.TupleType.createNamed(qualified_name.qualifiedName(), field_names, field_types, field_defaults);
this._cu.register_type(tt);
}
parseSourceIfNeeded(/* qualifier */) {
parseSourceIfNeeded(qualifier) {
if (!qualifier || this._loaded_sources.has(qualifier)) {
return;
}
this._loaded_sources.add(qualifier);
const src = this._source_loader(qualifier);
if (!src) {
return;
}
const program = this._cu.execution.parse(src.filename(), src.text_str(), null);
for (const statement of program.body) {
switch (statement.type) {
case 'def': {
break;
}
case 'class': {
const name = `${qualifier}.${statement.name}`;
this._to_be_defined.set(name, statement);
break;
}
default: {
break;
}
}
}
}
});
this.registerType('torch.jit.ScriptModuleDeserializer', class {
Expand All @@ -7734,12 +7839,15 @@ python.Execution = class {
this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
this._pickle_dir_prefix = pickle_dir_prefix || '';
this._tensor_dir_prefix = tensor_dir_prefix || '';
const SourceLoader = (qualifier) => {
return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
};
this._source_importer = new torch.jit.SourceImporter(
this._compilation_unit, this._constants_table,
new torch.jit.SourceLoader(this._reader, this._code_prefix), reader.version());
this._compilation_unit, this._constants_table, SourceLoader, reader.version());
}
deserialize() {
const execution = this._compilation_unit.execution;
execution._resolver = this._source_importer;
const code_prefix = this._code_prefix;
for (const name of this._reader.get_all_records()) {
if (name.startsWith(code_prefix) && name.endsWith('.py')) {
Expand Down Expand Up @@ -7914,6 +8022,17 @@ python.Execution = class {
};
return unpickler.load();
}
qualifierToArchivePath(qualifier, export_prefix) {
return `${export_prefix}${qualifier.replace(/\./g, '/')}.py`;
}
findSourceInArchiveFromQualifier(reader, export_prefix, qualifier) {
const path = this.qualifierToArchivePath(qualifier, export_prefix);
if (!reader.has_record(path)) {
return null;
}
const data = reader.get_record(path);
return new torch.jit.Source(data.peek(), path);
}
});
this.registerType('torch.package.PackageImporter', class {
constructor(reader) {
Expand Down Expand Up @@ -8215,6 +8334,9 @@ python.Execution = class {
this._functions = new Map();
this._classes = new Map();
}
register_type(namedType) {
this._classes.set(namedType.annotation_str, namedType);
}
register_function(fn) {
this._functions.set(fn.name, fn);
}
Expand All @@ -8228,14 +8350,11 @@ python.Execution = class {
}
}
get_type(name) {
return this._classes.get(name);
return this._classes.get(name.qualifiedName());
}
get_class(name) {
return this.get_type(name);
}
register_type(name, cls) {
this._classes.set(name, cls);
}
});
this.registerType('torch.jit._script.ScriptModule', class extends torch.nn.modules.module.Module {});
this.registerType('torch.jit._trace.TracedModule', class extends torch.jit._script.ScriptModule {});
Expand Down Expand Up @@ -8399,7 +8518,7 @@ python.Execution = class {
if (!cls) {
const name = obj_type.type_name;
if (name.startsWith('__torch__') || name.startsWith('torch.jit')) {
cls = this._cu.get_class(name);
cls = this._cu.get_class(new torch.jit.QualifiedName(name));
if (!cls) {
const torch = this._torch;
cls = new torch.ClassType(name, this._cu, true);
Expand Down Expand Up @@ -10247,13 +10366,6 @@ python.Execution = class {
return this._builtins;
}

source(file) {
return this._sources.has(file) ? this._sources.get(file) : null;
}

debug(/* file */) {
}

exec(code , context) {
const reader = new python.Parser(code, '', null);
const program = reader.parse();
Expand All @@ -10263,21 +10375,35 @@ python.Execution = class {
this.block(program.body, context);
}

parse(file) {
debug(/* file */) {
}

source(file) {
if (this._sources.has(file)) {
return this._sources.get(file);
}
return null;
}

read(file) {
const buffer = this.source(file);
if (buffer) {
const debug = this.debug(file);
const code = this._utf8Decoder.decode(buffer);
const parser = new python.Parser(code, file, debug);
const program = parser.parse();
if (!program) {
throw new python.Error(`Module '${file}' parse error.`);
}
return program;
return this.parse(file, buffer, debug);
}
return null;
}

parse(file, buffer, debug) {
const code = this._utf8Decoder.decode(buffer);
const parser = new python.Parser(code, file, debug);
const program = parser.parse();
if (!program) {
throw new python.Error(`Module '${file}' parse error.`);
}
return program;
}

import(name, current, level) {
if (level) {
let bits = current.split('.');
Expand All @@ -10303,7 +10429,7 @@ python.Execution = class {
const path = name.split('.').join('/');
module.__path__ = [path];
const file = `${path}.py`;
const program = this.parse(file);
const program = this.read(file);
if (program) {
module.__file__ = file;
for (const [name, value] of Object.entries(this.builtins)) {
Expand Down
18 changes: 18 additions & 0 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -2639,6 +2639,24 @@ pytorch.Execution = class extends python.Execution {
}
return node.addOutput();
}
const prefix = pytorch.Utility.target(target);
if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) {
const identifier = `${prefix}.${name}`;
const type = this._resolver.resolveType(identifier);
if (type instanceof torch.TupleType) {
const node = this._graph.create('prim::TupleConstruct');
node.setSourceRange(location);
this.graph.insertNode(node);
const evalArgs = args.map((expression) => this.expression(expression, context));
for (const arg of evalArgs) {
const value = this.variable(arg);
node.addInput(value);
}
const output = node.addOutput();
output.setType(type);
return output;
}
}
return super.call(target, name, args, context);
}
const [schema, evalArgs] = overload;
Expand Down
2 changes: 1 addition & 1 deletion test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5866,7 +5866,7 @@
"type": "pytorch",
"target": "pyg_model.pt",
"source": "https://github.com/lutzroeder/netron/files/10369483/pyg_model.zip[pyg_model.pt]",
"error": "Unknown function 'aten::linear'.",
"error": "Cannot read properties of undefined (reading 'str')",
"link": "https://github.com/lutzroeder/netron/issues/546"
},
{
Expand Down

0 comments on commit 25441ab

Please sign in to comment.