Skip to content

Commit

Permalink
Update pytorch.js (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 13, 2024
1 parent 8f2858e commit 7873d6e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
56 changes: 53 additions & 3 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6154,6 +6154,12 @@ python.Execution = class {
kind() {
return this._kind;
}
__str__() {
throw new python.Error('Not implemented.');
}
toString() {
return this.__str__();
}
});
this.registerType('torch.ClassType', class extends torch.Type {
constructor(qualified_name, cu, is_module) {
Expand Down Expand Up @@ -6189,6 +6195,9 @@ python.Execution = class {
getElementType() {
return this._elem;
}
__str__() {
return `Optional[${this.getElementType().toString()}]`;
}
});
this.registerType('torch.ListType', class extends torch.Type {
constructor(elem, size) {
Expand All @@ -6201,6 +6210,9 @@ python.Execution = class {
getElementType() {
return this._elem;
}
__str__() {
return `List[${this.getElementType().toString()}]`;
}
});
this.registerType('torch.FutureType', class extends torch.Type {
constructor(elem, size) {
Expand All @@ -6213,16 +6225,32 @@ python.Execution = class {
}
});
this.registerType('torch.TupleType', class extends torch.Type {
constructor() {
constructor(elements) {
super('TupleType');
this._elements = elements;
}
elements() {
return this._elements;
}
});
this.registerType('torch.TensorType', class extends torch.Type {
constructor() {
super('TensorType');
}
__str__() {
return 'Tensor';
}
});
this.registerType('torch.AnyType', class extends torch.Type {
constructor() {
super('AnyType');
}
});
this.registerType('torch.NoneType', class extends torch.Type {
constructor() {
super('NoneType');
}
});
this.registerType('torch.AnyType', class extends torch.Type {});
this.registerType('torch.NumberType', class extends torch.Type {
constructor() {
super('NumberType');
Expand All @@ -6232,11 +6260,17 @@ python.Execution = class {
constructor() {
super('BoolType');
}
__str__() {
return 'bool';
}
});
this.registerType('torch.IntType', class extends torch.Type {
constructor() {
super('IntType');
}
__str__() {
return 'int';
}
});
this.registerType('torch.SymIntType', class extends torch.Type {
constructor() {
Expand All @@ -6247,16 +6281,25 @@ python.Execution = class {
constructor() {
super('FloatType');
}
__str__() {
return 'float';
}
});
this.registerType('torch.StringType', class extends torch.Type {
constructor() {
super('StringType');
}
__str__() {
return 'str';
}
});
this.registerType('torch.ComplexType', class extends torch.Type {
constructor() {
super('ComplexType');
}
__str__() {
return 'complex';
}
});
this.registerType('torch.DictType', class extends torch.Type {
constructor(key, value) {
Expand All @@ -6271,7 +6314,14 @@ python.Execution = class {
return this._value;
}
});
this.registerType('torch.DeviceObjType', class extends torch.Type {});
this.registerType('torch.DeviceObjType', class extends torch.Type {
constructor() {
super('DeviceObjType');
}
__str__() {
return 'Device';
}
});
this.registerType('torch._C._GeneratorType', class extends torch.Type {});
this.registerType('torch.Argument', class {
constructor(name, type, real_type, N, default_value, kwarg_only, alias_info) {
Expand Down
9 changes: 7 additions & 2 deletions source/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,12 @@ def _argument(self, argument, value):
argument_type = '[' + size + ']' + argument_type
value = value.element_type
elif isinstance(value, Schema.DictType):
value = str(value)
name = value.getKeyType().name
key_type = self._primitives[name] if name in self._primitives else name
name = value.getValueType().name
value_type = self._primitives[name] if name in self._primitives else name
value = f'Dict({key_type}, {value_type})'
argument_type = value
else:
name = value.name
name = self._primitives[name] if name in self._primitives else name
Expand Down Expand Up @@ -498,7 +503,7 @@ def __init__(self, key_type, value_type):
self._key_type = key_type
self._value_type = value_type
def __str__(self):
return 'Dict[' + str(self._key_type) + ', ' + str(self._value_type) + ']'
return 'Dict(' + str(self._key_type) + ', ' + str(self._value_type) + ')'
def getKeyType(self): # pylint: disable=invalid-name,missing-function-docstring
return self._key_type
def getValueType(self): # pylint: disable=invalid-name,,missing-function-docstring
Expand Down

0 comments on commit 7873d6e

Please sign in to comment.