From 9dad70a2f4a0740ffe3a3842c067d50afd00d038 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Fri, 18 Aug 2023 05:52:27 +0200 Subject: [PATCH] Add `TensorFrame.index_select` functionality (#10) --- CHANGELOG.md | 1 + test/data/test_tensor_frame.py | 38 +++++++++++++++++++++++++++----- torch_frame/data/tensor_frame.py | 28 ++++++++++++++++++++++- torch_frame/typing.py | 5 +++++ 4 files changed, 66 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8588395..88f7eed2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `TensorFrame.index_select` functionality ([#10](https://github.com/pyg-team/pytorch-frame/pull/10)) - Added `Dataset.to_tensor_frame` functionality ([#9](https://github.com/pyg-team/pytorch-frame/pull/9)) - Added base classes `TensorEncoder`, `FeatureEncoder`, `TableConv`, `Decoder` ([#5](https://github.com/pyg-team/pytorch-frame/pull/5)) - Added `TensorFrame` ([#4](https://github.com/pyg-team/pytorch-frame/pull/4)) diff --git a/test/data/test_tensor_frame.py b/test/data/test_tensor_frame.py index 4329595e..53fd99ab 100644 --- a/test/data/test_tensor_frame.py +++ b/test/data/test_tensor_frame.py @@ -5,21 +5,30 @@ from torch_frame import TensorFrame -def test_tensor_frame_basics(): +def get_tensor_frame(num_rows: int) -> TensorFrame: x_dict = { - torch_frame.categorical: torch.randint(0, 3, size=(10, 3)), - torch_frame.numerical: torch.randn(size=(10, 2)), + torch_frame.categorical: torch.randint(0, 3, size=(num_rows, 3)), + torch_frame.numerical: torch.randn(size=(num_rows, 2)), } col_names_dict = { torch_frame.categorical: ['a', 'b', 'c'], torch_frame.numerical: ['x', 'y'], } - y = torch.randn(10) + y = torch.randn(num_rows) + + return TensorFrame(x_dict=x_dict, col_names_dict=col_names_dict, y=y) - tf = TensorFrame(x_dict=x_dict, col_names_dict=col_names_dict, y=y) +def test_tensor_frame_basics(): + tf = get_tensor_frame(num_rows=10) assert tf.num_rows == 10 + assert str(tf) == ("TensorFrame(\n" + " num_rows=10,\n" + " categorical (3): ['a', 'b', 'c'],\n" + " numerical (2): ['x', 'y'],\n" + ")") + def test_tensor_frame_error(): x_dict = { @@ -54,3 +63,22 @@ def test_tensor_frame_error(): y = torch.randn(11) with pytest.raises(ValueError, match='not aligned'): TensorFrame(x_dict=x_dict, col_names_dict=col_names_dict, y=y) + + +@pytest.mark.parametrize('index', [ + 4, + [4, 8], + range(2, 6), + torch.tensor([4, 8]), +]) +def test_tensor_frame_index_select(index): + tf = get_tensor_frame(num_rows=10) + + out = tf[index] + + if isinstance(index, int): + assert out.num_rows == 1 + else: + assert out.num_rows == len(index) + + assert out.col_names_dict == tf.col_names_dict diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index 9d51ae90..d4cdebc8 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -1,12 +1,14 @@ +import copy from dataclasses import dataclass from typing import Dict, List, Optional from torch import Tensor import torch_frame +from torch_frame.typing import IndexSelectType -@dataclass +@dataclass(repr=False) class TensorFrame: r"""TensorFrame holds a :pytorch:`PyTorch` tensor for each table column. Table columns are first organized into their semantic types (e.g., @@ -52,3 +54,27 @@ def __post_init__(self): @property def num_rows(self) -> int: return len(next(iter(self.x_dict.values()))) + + def __repr__(self) -> str: + stype_repr = '\n'.join([ + f' {stype.value} ({len(col_names)}): {col_names},' + for stype, col_names in self.col_names_dict.items() + ]) + + return (f'{self.__class__.__name__}(\n' + f' num_rows={self.num_rows},\n' + f'{stype_repr}\n' + f')') + + def __getitem__(self, index: IndexSelectType) -> 'TensorFrame': + if isinstance(index, int): + index = [index] + + out = copy.copy(self) + + out.x_dict = {stype: x[index] for stype, x in out.x_dict.items()} + out.col_names_dict = copy.copy(out.col_names_dict) + if out.y is not None: + out.y = out.y[index] + + return out diff --git a/torch_frame/typing.py b/torch_frame/typing.py index d807edf9..5cebc27f 100644 --- a/torch_frame/typing.py +++ b/torch_frame/typing.py @@ -1,4 +1,9 @@ +from typing import List, Union + import pandas as pd +from torch import Tensor Series = pd.Series DataFrame = pd.DataFrame + +IndexSelectType = Union[int, List[int], range, slice, Tensor]