Skip to content

Commit

Permalink
Update pytorch.js (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 20, 2024
1 parent 22d8cbe commit c0199f2
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 169 deletions.
149 changes: 104 additions & 45 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6157,6 +6157,9 @@ python.Execution = class {
this._name = name;
}
}
static get(kind, name) {
return new torch.Type(kind, name);
}
kind() {
return this._kind;
}
Expand Down Expand Up @@ -6243,6 +6246,9 @@ python.Execution = class {
super('OptionalType');
this._elem = elem;
}
static get(elem) {
return new torch.OptionalType(elem);
}
getElementType() {
return this._elem;
}
Expand All @@ -6258,6 +6264,9 @@ python.Execution = class {
super('ListType');
this._elem = elem;
}
static get(elem) {
return new torch.ListType(elem);
}
getElementType() {
return this._elem;
}
Expand All @@ -6269,10 +6278,12 @@ python.Execution = class {
}
});
this.registerType('torch.FutureType', class extends torch.Type {
constructor(elem, size) {
constructor(elem) {
super('FutureType');
this._elem = elem;
this._size = size;
}
static get(elem) {
return new torch.FutureType(elem);
}
getElementType() {
return this._elem;
Expand All @@ -6289,6 +6300,9 @@ python.Execution = class {
super('RRefType');
this._elem = elem;
}
get(elem) {
return new torch.RRefType(elem);
}
getElementType() {
return this._elem;
}
Expand All @@ -6304,6 +6318,9 @@ python.Execution = class {
super('AwaitType');
this._elem = elem;
}
static get(elem) {
return new torch.AwaitType(elem);
}
getElementType() {
return this._elem;
}
Expand All @@ -6319,6 +6336,9 @@ python.Execution = class {
super('TupleType');
this._elements = elements;
}
static get(elements) {
return new torch.TupleType(elements);
}
elements() {
return this._elements;
}
Expand All @@ -6329,17 +6349,6 @@ python.Execution = class {
return `(${this.elements().map((elem) => elem.str()).join(', ')})`;
}
});
this.registerType('torch.TensorType', class extends torch.Type {
constructor() {
super('TensorType');
}
__str__() {
return 'Tensor';
}
str() {
return 'Tensor';
}
});
this.registerType('torch.AnyType', class extends torch.Type {
constructor() {
super('AnyType');
Expand All @@ -6356,10 +6365,29 @@ python.Execution = class {
return this.kind();
}
});
this.registerType('torch.TensorType', class extends torch.Type {
constructor() {
super('TensorType');
}
static get() {
torch.TensorType.value = torch.TensorType.value || new torch.TensorType();
return torch.TensorType.value;
}
__str__() {
return 'Tensor';
}
str() {
return 'Tensor';
}
});
this.registerType('torch.NumberType', class extends torch.Type {
constructor() {
super('NumberType');
}
static get() {
torch.NumberType.value = torch.NumberType.value || new torch.NumberType();
return torch.NumberType.value;
}
__str__() {
return 'number';
}
Expand All @@ -6371,6 +6399,10 @@ python.Execution = class {
constructor() {
super('BoolType');
}
static get() {
torch.BoolType.value = torch.BoolType.value || new torch.BoolType();
return torch.BoolType.value;
}
__str__() {
return 'bool';
}
Expand All @@ -6382,6 +6414,10 @@ python.Execution = class {
constructor() {
super('IntType');
}
static get() {
torch.IntType.value = torch.IntType.value || new torch.IntType();
return torch.IntType.value;
}
__str__() {
return 'int';
}
Expand All @@ -6393,6 +6429,10 @@ python.Execution = class {
constructor() {
super('SymIntType');
}
static get() {
torch.SymIntType.value = torch.SymIntType.value || new torch.SymIntType();
return torch.SymIntType.value;
}
__str__() {
return 'int';
}
Expand All @@ -6404,6 +6444,10 @@ python.Execution = class {
constructor() {
super('FloatType');
}
static get() {
torch.FloatType.value = torch.FloatType.value || new torch.FloatType();
return torch.FloatType.value;
}
__str__() {
return 'float';
}
Expand All @@ -6415,6 +6459,10 @@ python.Execution = class {
constructor() {
super('StringType');
}
static get() {
torch.StringType.value = torch.StringType.value || new torch.StringType();
return torch.StringType.value;
}
__str__() {
return 'str';
}
Expand All @@ -6426,6 +6474,10 @@ python.Execution = class {
constructor() {
super('ComplexType');
}
static get() {
torch.ComplexType.value = torch.ComplexType.value || new torch.ComplexType();
return torch.ComplexType.value;
}
__str__() {
return 'complex';
}
Expand All @@ -6439,6 +6491,9 @@ python.Execution = class {
this._key = key;
this._value = value;
}
static get(key, value) {
return new torch.DictType(key, value);
}
getKeyType() {
return this._key;
}
Expand Down Expand Up @@ -6478,6 +6533,10 @@ python.Execution = class {
constructor() {
super('GeneratorType');
}
static get() {
torch._C._GeneratorType.value = torch._C._GeneratorType.value || new torch._C._GeneratorType();
return torch._C._GeneratorType.value;
}
__str__() {
return 'Generator';
}
Expand Down Expand Up @@ -6613,26 +6672,26 @@ python.Execution = class {
const value = L.value;
L.next();
switch (value) {
case 'Tensor': return new torch.TensorType();
case 'bool': return new torch.BoolType();
case 'int': return new torch.IntType();
case 'float': return new torch.FloatType();
case 'complex': return new torch.ComplexType();
case 'str': return new torch.StringType();
case 'SymInt': return new torch.SymIntType();
case 'Scalar': return new torch.NumberType();
case 'ScalarType': return new torch.Type('ScalarTypeType');
case 'Tensor': return torch.TensorType.get();
case 'bool': return torch.BoolType.get();
case 'int': return torch.IntType.get();
case 'float': return torch.FloatType.get();
case 'complex': return torch.ComplexType.get();
case 'str': return torch.StringType.get();
case 'SymInt': return torch.SymIntType.get();
case 'Scalar': return torch.NumberType.get();
case 'ScalarType': return torch.Type.get('ScalarTypeType');
case 'Device': return new torch.DeviceObjType();
case 'Layout': return new torch.Type('Layout');
case 'MemoryFormat': return new torch.Type('MemoryFormat');
case 'Generator': return new torch._C._GeneratorType();
case 't': case 't1': case 't2': case 'tVal': return new torch.Type('VarType', value);
case 'Layout': return torch.Type.get('Layout');
case 'MemoryFormat': return torch.Type.get('MemoryFormat');
case 'Generator': return torch._C._GeneratorType.get();
case 't': case 't1': case 't2': case 'tVal': return torch.Type.get('VarType', value);
case 'Any': return new torch.AnyType();
case 'AnyEnumType': return new torch.Type('AnyEnumType');
case 'QScheme': return new torch.Type('QSchemeType');
case 'AnyEnumType': return torch.Type.get('AnyEnumType');
case 'QScheme': return torch.Type.get('QSchemeType');
case 'Stream': return new torch.StreamObjType();
case 'Storage': return new torch.Type('Storage');
case 'AnyClassType': return new torch.Type('AnyClassType');
case 'Storage': return torch.Type.get('Storage');
case 'AnyClassType': return torch.Type.get('AnyClassType');
case 'NoneType': return new torch.NoneType();
default: throw new python.Error(`Unsupported type '${value}'.`);
}
Expand All @@ -6655,7 +6714,7 @@ python.Execution = class {
L.eat(',');
L.whitespace(0);
}
real_value = new torch.TupleType(types);
real_value = torch.TupleType.get(types);
fake_value = real_value;
} else if (L.value === 'Future') {
L.next();
Expand All @@ -6664,7 +6723,7 @@ python.Execution = class {
const subtype = p.first;
// const subalias = p.second;
L.expect(')');
real_value = new torch.FutureType(subtype);
real_value = torch.FutureType.get(subtype);
fake_value = real_value;
} else if (L.value === 'Await') {
L.next();
Expand All @@ -6673,7 +6732,7 @@ python.Execution = class {
const subtype = p.first;
// const subalias = p.second;
L.expect(')');
real_value = new torch.AwaitType(subtype);
real_value = torch.AwaitType.get(subtype);
fake_value = real_value;
} else if (L.value === 'RRef') {
L.next();
Expand All @@ -6682,11 +6741,11 @@ python.Execution = class {
const subtype = p.first;
// const subalias = p.second;
L.expect(')');
real_value = new torch.RRefType(subtype);
real_value = torch.RRefType.get(subtype);
fake_value = real_value;
} else if (L.value === 'Tensor') {
L.next();
real_value = new torch.TensorType();
real_value = torch.TensorType.get();
fake_value = real_value;
alias_info = this.parseAliasAnnotation();
} else if (L.value === 'Dict') {
Expand All @@ -6698,7 +6757,7 @@ python.Execution = class {
const value_type = this.parseType().first;
L.expect(')');
alias_info = this.parseAliasAnnotation();
real_value = new torch.DictType(key_type, value_type);
real_value = torch.DictType.get(key_type, value_type);
fake_value = real_value;
} else if (L.eat('Union')) {
L.next();
Expand Down Expand Up @@ -6730,15 +6789,15 @@ python.Execution = class {
real_value.kind() === 'MemoryFormat' ||
real_value.kind() === 'Layout' ||
real_value.kind() === 'SymInt') {
fake_value = new torch.IntType();
fake_value = torch.IntType.get();
}
alias_info = this.parseAliasAnnotation();
}
while (true) {
if (L.kind === '[]') {
L.expect('[]');
fake_value = new torch.ListType(fake_value);
real_value = new torch.ListType(real_value);
fake_value = torch.ListType.get(fake_value);
real_value = torch.ListType.get(real_value);
let container = this.parseAliasAnnotation();
if (alias_info) {
if (!container) {
Expand All @@ -6749,8 +6808,8 @@ python.Execution = class {
}
alias_info = container;
} else if (L.eat('?')) {
fake_value = new torch.OptionalType(fake_value);
real_value = new torch.OptionalType(real_value);
fake_value = torch.OptionalType.get(fake_value);
real_value = torch.OptionalType.get(real_value);
} else {
break;
}
Expand Down Expand Up @@ -6807,8 +6866,8 @@ python.Execution = class {
L.whitespace(0);
let N = null;
if (L.eat('[')) {
fake_type = new torch.ListType(fake_type);
real_type = new torch.ListType(real_type);
fake_type = torch.ListType.get(fake_type);
real_type = torch.ListType.get(real_type);
if (L.kind === '#') {
N = Number(L.value);
L.next();
Expand All @@ -6825,9 +6884,9 @@ python.Execution = class {
alias_info = container;
if (L.eat('?')) {
/* eslint-disable no-unused-vars */
fake_type = new torch.OptionalType(fake_type);
fake_type = torch.OptionalType.get(fake_type);
/* eslint-enable no-unused-vars */
real_type = new torch.OptionalType(real_type);
real_type = torch.OptionalType.get(real_type);
}
}
let name = null;
Expand Down
Loading

0 comments on commit c0199f2

Please sign in to comment.