Skip to content

Commit

Permalink
Support tensor storage for compat interface
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 21, 2023
1 parent 16396f5 commit 1ed7043
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions tat/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,14 @@ def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]:
# Function renames


def _compat_function(focus_type: type) -> typing.Callable[[typing.Callable], typing.Callable]:
def _compat_function(focus_type: type, name: str | None = None) -> typing.Callable[[typing.Callable], typing.Callable]:

def _result(function: typing.Callable) -> typing.Callable:
name = function.__name__
setattr(focus_type, name, function)
if name is None:
attr_name = function.__name__
else:
attr_name = name
setattr(focus_type, attr_name, function)
return function

return _result
Expand All @@ -303,6 +306,14 @@ def zero(self: T) -> T:
return self.zero_()


@_compat_function(T, name="storage") # type: ignore[misc]
@property
def storage(self: T) -> typing.Any:
"Get the storage of the tensor"
assert self.data.is_contiguous()
return self.data.numpy().reshape([-1])


# Exponential arguments

origin_exponential = T.exponential
Expand Down

0 comments on commit 1ed7043

Please sign in to comment.