Skip to content

Commit

Permalink
Add TensorFrame.index_select functionality (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Aug 18, 2023
1 parent 28ed304 commit 9dad70a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `TensorFrame.index_select` functionality ([#10](https://github.com/pyg-team/pytorch-frame/pull/10))
- Added `Dataset.to_tensor_frame` functionality ([#9](https://github.com/pyg-team/pytorch-frame/pull/9))
- Added base classes `TensorEncoder`, `FeatureEncoder`, `TableConv`, `Decoder` ([#5](https://github.com/pyg-team/pytorch-frame/pull/5))
- Added `TensorFrame` ([#4](https://github.com/pyg-team/pytorch-frame/pull/4))
Expand Down
38 changes: 33 additions & 5 deletions test/data/test_tensor_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,30 @@
from torch_frame import TensorFrame


def test_tensor_frame_basics():
def get_tensor_frame(num_rows: int) -> TensorFrame:
x_dict = {
torch_frame.categorical: torch.randint(0, 3, size=(10, 3)),
torch_frame.numerical: torch.randn(size=(10, 2)),
torch_frame.categorical: torch.randint(0, 3, size=(num_rows, 3)),
torch_frame.numerical: torch.randn(size=(num_rows, 2)),
}
col_names_dict = {
torch_frame.categorical: ['a', 'b', 'c'],
torch_frame.numerical: ['x', 'y'],
}
y = torch.randn(10)
y = torch.randn(num_rows)

return TensorFrame(x_dict=x_dict, col_names_dict=col_names_dict, y=y)

tf = TensorFrame(x_dict=x_dict, col_names_dict=col_names_dict, y=y)

def test_tensor_frame_basics():
tf = get_tensor_frame(num_rows=10)
assert tf.num_rows == 10

assert str(tf) == ("TensorFrame(\n"
" num_rows=10,\n"
" categorical (3): ['a', 'b', 'c'],\n"
" numerical (2): ['x', 'y'],\n"
")")


def test_tensor_frame_error():
x_dict = {
Expand Down Expand Up @@ -54,3 +63,22 @@ def test_tensor_frame_error():
y = torch.randn(11)
with pytest.raises(ValueError, match='not aligned'):
TensorFrame(x_dict=x_dict, col_names_dict=col_names_dict, y=y)


@pytest.mark.parametrize('index', [
    4,
    [4, 8],
    range(2, 6),
    torch.tensor([4, 8]),
])
def test_tensor_frame_index_select(index):
    r"""Checks that indexing a :class:`TensorFrame` keeps the column layout
    and yields the expected number of rows for each supported index type."""
    tf = get_tensor_frame(num_rows=10)
    out = tf[index]

    # A plain integer selects exactly one row; every other index type
    # selects one row per entry.
    expected_num_rows = 1 if isinstance(index, int) else len(index)

    assert out.num_rows == expected_num_rows
    assert out.col_names_dict == tf.col_names_dict
28 changes: 27 additions & 1 deletion torch_frame/data/tensor_frame.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import copy
from dataclasses import dataclass
from typing import Dict, List, Optional

from torch import Tensor

import torch_frame
from torch_frame.typing import IndexSelectType


@dataclass
@dataclass(repr=False)
class TensorFrame:
r"""TensorFrame holds a :pytorch:`PyTorch` tensor for each table column.
Table columns are first organized into their semantic types (e.g.,
Expand Down Expand Up @@ -52,3 +54,27 @@ def __post_init__(self):
@property
def num_rows(self) -> int:
return len(next(iter(self.x_dict.values())))

def __repr__(self) -> str:
stype_repr = '\n'.join([
f' {stype.value} ({len(col_names)}): {col_names},'
for stype, col_names in self.col_names_dict.items()
])

return (f'{self.__class__.__name__}(\n'
f' num_rows={self.num_rows},\n'
f'{stype_repr}\n'
f')')

def __getitem__(self, index: IndexSelectType) -> 'TensorFrame':
    r"""Returns a new :class:`TensorFrame` that holds only the rows
    selected by :obj:`index`.

    Args:
        index: The rows to select. A single integer or a zero-dimensional
            tensor selects exactly one row. Any other supported index
            (list of integers, :obj:`range`, :obj:`slice` or tensor) is
            passed to the tensor indexing operation unchanged.

    Returns:
        A shallow copy of this object in which every entry of
        :obj:`x_dict` (and :obj:`y`, if present) has been indexed by
        :obj:`index`. The original object is left untouched.
    """
    # Scalar indices would drop the row dimension (`x[4]` has one
    # dimension less than `x[[4]]`), which would in turn corrupt
    # `num_rows`. Wrap them so that a single row is still a frame.
    if isinstance(index, int):
        index = [index]
    elif isinstance(index, Tensor) and index.dim() == 0:
        index = index.unsqueeze(0)

    out = copy.copy(self)

    out.x_dict = {stype: x[index] for stype, x in self.x_dict.items()}
    # Give the result its own dictionary so that adding or removing keys
    # on it does not affect `self`. The column name lists themselves are
    # still shared.
    out.col_names_dict = copy.copy(self.col_names_dict)
    if self.y is not None:
        out.y = self.y[index]

    return out
5 changes: 5 additions & 0 deletions torch_frame/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import List, Union

import pandas as pd
from torch import Tensor

# Short aliases for the pandas container types.
Series = pd.Series
DataFrame = pd.DataFrame

# Index types accepted for row selection: a single integer, a list of
# integers, a `range`, a `slice` or a tensor.
IndexSelectType = Union[int, List[int], range, slice, Tensor]

0 comments on commit 9dad70a

Please sign in to comment.