Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 21, 2024
1 parent e97c8a1 commit cee1b00
Show file tree
Hide file tree
Showing 5 changed files with 450 additions and 207 deletions.
31 changes: 18 additions & 13 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4377,6 +4377,10 @@ python.Execution = class {
}
throw new python.Error(`Schema '${op_name}.${overload_name}' not found.`);
});
this.registerFunction('torch._C._jit_get_schemas_for_operator', (op_name) => {
const registry = torch._C._get_registry();
return registry.getAllOperatorsFor(op_name).map((op) => op.schema());
});
this.registerFunction('torch._C._jit_get_operation', (op_name) => {
const registry = torch._C._get_registry();
const sortedOps = registry.getAllOperatorsFor(op_name);
Expand Down Expand Up @@ -6152,14 +6156,14 @@ python.Execution = class {

});
this.registerType('torch.Type', class {
constructor(kind, name) {
constructor(kind, annotation_str) {
this._kind = kind;
if (name) {
this._name = name;
if (annotation_str) {
this._annotation_str = annotation_str;
}
}
static get(kind, name) {
return new torch.Type(kind, name);
static get(kind, annotation_str) {
return new torch.Type(kind, annotation_str);
}
kind() {
return this._kind;
Expand All @@ -6171,8 +6175,8 @@ python.Execution = class {
throw new python.Error(`Not implemented '${this.kind()}'.`);
}
str() {
if (this._kind === 'VarType' && this._name) {
return this._name;
if (this._kind === 'VarType' && this._annotation_str) {
return this._annotation_str;
} else if (this._kind === 'ScalarTypeType') {
return 'ScalarType';
} else if (this._kind === 'QSchemeType') {
Expand Down Expand Up @@ -6722,6 +6726,7 @@ python.Execution = class {
case 't': case 't1': case 't2': case 'tVal': return torch.Type.get('VarType', value);
case 'Any': return torch.AnyType.get();
case 'AnyEnumType': return torch.Type.get('AnyEnumType');
case 'Dimname': return torch.StringType.get();
case 'QScheme': return torch.Type.get('QSchemeType');
case 'Stream': return torch.StreamObjType.get();
case 'Storage': return torch.Type.get('Storage');
Expand Down Expand Up @@ -7036,7 +7041,7 @@ python.Execution = class {
});
this.registerType('torch.FunctionSchema', class {
constructor(name, overload_name, args, returns, is_vararg, is_varret) {
let index = name.indexOf('(');
const index = name.indexOf('(');
if (index === -1) {
this._name = name;
this._overload_name = overload_name;
Expand All @@ -7046,15 +7051,15 @@ python.Execution = class {
this._is_varret = is_varret;
} else {
const value = name.substring(0, index).trim();
this._buffer = name.substring(index, name.length);
index = value.indexOf('.');
if (index === -1) {
const dot = value.indexOf('.');
if (dot === -1) {
this._name = value;
this._overload_name = '';
} else {
this._name = value.substring(0, index);
this._overload_name = value.substring(index + 1, value.length);
this._name = value.substring(0, dot);
this._overload_name = value.substring(dot + 1, value.length);
}
this._buffer = name.substring(index, name.length);
}
}
static parse(schema) {
Expand Down
29 changes: 29 additions & 0 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,16 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::__is__(t1 self, t2 obj) -> bool",
"inputs": [
{ "name": "self", "type": "t1" },
{ "name": "obj", "type": "t2" }
],
"outputs": [
{ "type": "boolean" }
]
},
{
"name": "aten::__isnot__(t1 self, t2 obj) -> bool",
"inputs": [
Expand Down Expand Up @@ -5084,6 +5094,25 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::device(str a) -> Device",
"inputs": [
{ "name": "a", "type": "string" }
],
"outputs": [
{ "type": "Device" }
]
},
{
"name": "aten::device.with_index(str type, int index) -> Device",
"inputs": [
{ "name": "type", "type": "string" },
{ "name": "index", "type": "int64" }
],
"outputs": [
{ "type": "Device" }
]
},
{
"name": "aten::diag(Tensor self, int diagonal=0) -> Tensor",
"inputs": [
Expand Down
Loading

0 comments on commit cee1b00

Please sign in to comment.