Skip to content

Commit

Permalink
Add base classes (#7)
Browse files Browse the repository at this point in the history
Add base classes for `TensorEncoder`, `FeatureEncoder`, `TableConv`,
`Decoder`.
Also add a simple test that instantiates each class and tests the e2e
pipeline.

---------

Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
weihua916 and rusty1s authored Aug 16, 2023
1 parent beb3089 commit 265451d
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `Titanic` dataset ([#3](https://github.com/pyg-team/pytorch-frame/pull/3))
- Added `Dataset` base class ([#3](https://github.com/pyg-team/pytorch-frame/pull/3))
- Added `TensorFrame` ([#4](https://github.com/pyg-team/pytorch-frame/pull/4))
- Added base classes `TensorEncoder`, `FeatureEncoder`, `TableConv`, `Decoder` ([#7](https://github.com/pyg-team/pytorch-frame/pull/7))

### Changed

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ classifiers=[
"Programming Language :: Python :: 3 :: Only",
]
dependencies=[
"numpy",
"pandas",
"torch",
]

[project.optional-dependencies]
Expand Down
108 changes: 108 additions & 0 deletions test/nn/test_simple_basecls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict

import numpy as np
import torch
from torch import Tensor

from torch_frame.nn import FeatureEncoder, TableConv, Decoder
from torch_frame.encoder import TensorEncoder
from torch_frame import TensorFrame, stype
from torch_frame.typing import DataFrame


def test_simple_basecls():
    r"""Instantiate every base class through a minimal subclass and run the
    resulting end-to-end pipeline, checking the tensor shape after each stage.
    """
    class SimpleTensorEncoder(TensorEncoder):
        def to_tensor(self, df: DataFrame) -> TensorFrame:
            # Per-column tensors and column names, grouped by stype.
            x_list_dict: Dict[stype, List[Tensor]] = defaultdict(list)
            col_names_dict: Dict[stype, List[str]] = defaultdict(list)

            for col_name in df.keys():
                stype_name = self.col2stype[col_name]
                # Every column is turned into a [num_rows, 1] tensor.
                tensor = torch.from_numpy(df[col_name].values).view(-1, 1)
                # Categorical values are used as embedding indices further
                # down, so they are cast to long; the rest is cast to float.
                if stype_name == stype.categorical:
                    dtype = torch.long
                else:
                    dtype = torch.float
                x_list_dict[stype_name].append(tensor.to(dtype))
                col_names_dict[stype_name].append(col_name)

            # Concatenate the columns of each stype along the column axis.
            x_dict: Dict[stype, Tensor] = {
                stype_name: torch.cat(x_list, dim=1)
                for stype_name, x_list in x_list_dict.items()
            }
            # Hand over a plain dict so that later lookups of a missing stype
            # raise instead of silently inserting an empty list.
            return TensorFrame(
                x_dict=x_dict,
                col_names_dict=dict(col_names_dict),
            )

    class SimpleFeatureEncoder(FeatureEncoder):
        def __init__(
            self,
            out_channels: int,
            num_numerical: int,
            num_categories: List[int],
        ):
            super().__init__()
            self.lin_numerical = torch.nn.Linear(num_numerical, out_channels)
            # One embedding table per categorical column.
            self.emb_categorical = torch.nn.ModuleList([
                torch.nn.Embedding(num_category, out_channels)
                for num_category in num_categories
            ])

        def forward(self, tf: TensorFrame) -> Tuple[Tensor, List[str]]:
            # The trailing singleton axis added here is the input feature
            # axis of `lin_numerical`, so every numerical column is projected
            # independently.
            x_num = self.lin_numerical(
                tf.x_dict[stype.numerical].unsqueeze(dim=2))

            # Embed each categorical column with its own table and stack the
            # results along the column axis.
            cat_index = tf.x_dict[stype.categorical]
            x_cat = torch.stack(
                [
                    self.emb_categorical[i](cat_index[:, i])
                    for i in range(cat_index.size(1))
                ],
                dim=1,
            )

            # Numerical columns come first, followed by categorical ones; the
            # returned column names follow the same order.
            x = torch.cat([x_num, x_cat], dim=1)
            col_names = (tf.col_names_dict[stype.numerical] +
                         tf.col_names_dict[stype.categorical])
            return x, col_names

    class SimpleTableConv(TableConv):
        def __init__(self, in_channels: int, out_channels: int):
            super().__init__()
            self.lin = torch.nn.Linear(in_channels, out_channels)

        def forward(self, x: Tensor) -> Tensor:
            # Flatten batch and column axes, apply the shared linear layer,
            # then restore the leading two axes.
            B, C, H = x.shape
            x = x.view(-1, H)
            return self.lin(x).view(B, C, -1)

    class SimpleDecoder(Decoder):
        def forward(self, x: Tensor) -> Tensor:
            # Average over the last (channel) axis.
            return torch.mean(x, dim=-1)

    df = DataFrame({
        'num1': np.random.randn(10),
        'num2': np.random.randn(10),
        'cat1': np.random.randint(0, 3, 10),
        'cat2': np.random.randint(0, 5, 10),
    })
    tensor_encoder = SimpleTensorEncoder(
        col2stype={
            'num1': stype.numerical,
            'num2': stype.numerical,
            'cat1': stype.categorical,
            'cat2': stype.categorical,
        })
    # `num_numerical=1` because each numerical column is fed to the linear
    # layer as a single scalar feature (see `SimpleFeatureEncoder.forward`).
    feat_encoder = SimpleFeatureEncoder(out_channels=8, num_numerical=1,
                                        num_categories=[3, 5])
    table_conv1 = SimpleTableConv(in_channels=8, out_channels=16)
    table_conv2 = SimpleTableConv(in_channels=16, out_channels=8)
    decoder = SimpleDecoder()

    tf = tensor_encoder.to_tensor(df)
    x, col_names = feat_encoder(tf)
    # [batch_size, num_cols, hidden_channels]
    assert x.shape == (10, 4, 8)
    assert col_names == list(df.keys())
    x = table_conv1(x)
    assert x.shape == (10, 4, 16)
    x = table_conv2(x)
    assert x.shape == (10, 4, 8)
    x = decoder(x)
    assert x.shape == (10, 4)
5 changes: 5 additions & 0 deletions torch_frame/encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .encoder import TensorEncoder

__all__ = [
'TensorEncoder',
]
25 changes: 25 additions & 0 deletions torch_frame/encoder/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import Dict

from torch_frame import TensorFrame, stype
from torch_frame.typing import DataFrame


class TensorEncoder(ABC):
    r"""Abstract interface for objects that transform an input DataFrame into
    a TensorFrame.

    Args:
        col2stype (Dict[str, stype]): A dictionary that maps each column name
            in the DataFrame to its stype.
    """
    def __init__(self, col2stype: Dict[str, stype]):
        self.col2stype = col2stype

    @abstractmethod
    def to_tensor(self, df: DataFrame) -> TensorFrame:
        r"""Convert the given DataFrame into a TensorFrame.

        Subclasses must override this method.
        """
        raise NotImplementedError
9 changes: 9 additions & 0 deletions torch_frame/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .encoder import FeatureEncoder
from .conv import TableConv
from .decoder import Decoder

__all__ = [
'FeatureEncoder',
'TableConv',
'Decoder',
]
5 changes: 5 additions & 0 deletions torch_frame/nn/conv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .table_conv import TableConv

__all__ = [
'TableConv',
]
23 changes: 23 additions & 0 deletions torch_frame/nn/conv/table_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from typing import Any

from torch import Tensor
from torch.nn import Module


class TableConv(Module, ABC):
    r"""Abstract base module for table convolutions, i.e. layers that
    transform an input column-wise PyTorch tensor.
    """
    @abstractmethod
    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any:
        r"""Map a column-wise 3-dimensional tensor to another column-wise
        3-dimensional tensor.

        Subclasses must override this method.

        Args:
            x (Tensor): Input column-wise tensor of shape
                :obj:`[batch_size, num_cols, hidden_channels]`.
            args (Any): Extra arguments.
            kwargs (Any): Extra keyword arguments.
        """
        raise NotImplementedError
5 changes: 5 additions & 0 deletions torch_frame/nn/decoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .decoder import Decoder

__all__ = [
'Decoder',
]
21 changes: 21 additions & 0 deletions torch_frame/nn/decoder/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Any

from torch import Tensor
from torch.nn import Module


class Decoder(Module, ABC):
    r"""Base class for decoder that transforms the input column-wise pytorch
    tensor into output tensor on which prediction head is applied.
    """
    @abstractmethod
    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any:
        r"""Decode the column-wise tensor :obj:`x` into the output tensor.

        Subclasses must override this method.

        Args:
            x (Tensor): Input column-wise tensor of shape
                :obj:`[batch_size, num_cols, hidden_channels]`.
            args (Any): Extra arguments.
            kwargs (Any): Extra keyword arguments.
        """
        raise NotImplementedError
5 changes: 5 additions & 0 deletions torch_frame/nn/encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .encoder import FeatureEncoder

__all__ = [
'FeatureEncoder',
]
28 changes: 28 additions & 0 deletions torch_frame/nn/encoder/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from typing import Tuple, List

from torch import Tensor
from torch.nn import Module

from torch_frame import TensorFrame


class FeatureEncoder(Module, ABC):
    r"""Abstract base module for feature encoders, which transform an input
    TensorFrame into :obj:`(x, col_names)`. Here, :obj:`x` is the column-wise
    PyTorch tensor and :obj:`col_names` holds the names of its columns. This
    class can contain learnable parameters and missing value handling.
    """
    @abstractmethod
    def forward(self, df: TensorFrame) -> Tuple[Tensor, List[str]]:
        r"""Encode a TensorFrame into :obj:`(x, col_names)`.

        Subclasses must override this method.

        Args:
            df (TensorFrame): Input TensorFrame.

        Returns:
            x (Tensor): Output column-wise PyTorch tensor of shape
                :obj:`[batch_size, num_cols, hidden_channels]`.
            col_names (List[str]): Column names of :obj:`x`. The length needs
                to be :obj:`num_cols`.
        """
        raise NotImplementedError

0 comments on commit 265451d

Please sign in to comment.