Skip to content

Commit

Permalink
Add is_floating_point() to multi tensors (#445)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
akihironitta and pre-commit-ci[bot] authored Sep 6, 2024
1 parent 200b962 commit 63cafb7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445))
- Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421))

### Changed
Expand Down
11 changes: 11 additions & 0 deletions test/data/test_multi_embedding_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ def test_size():
assert met.shape[2] == -1


def test_is_floating_point():
met = MultiEmbeddingTensor.from_tensor_list([
torch.tensor([[1]], dtype=torch.long),
])
assert not met.is_floating_point()
met = MultiEmbeddingTensor.from_tensor_list([
torch.tensor([[1]], dtype=torch.float32),
])
assert met.is_floating_point()


def test_fillna_col():
# Creat a MultiEmbeddingTensor containing all -1's
# In MultiEmbeddingTensor with torch.long dtype,
Expand Down
9 changes: 9 additions & 0 deletions test/data/test_multi_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def column_select(
return new_tensor_mat


def test_is_floating_point():
met = MultiNestedTensor.from_tensor_mat(
[[torch.tensor([1], dtype=torch.long)]])
assert not met.is_floating_point()
met = MultiNestedTensor.from_tensor_mat(
[[torch.tensor([1], dtype=torch.float32)]])
assert met.is_floating_point()


def test_fillna_col():
# Creat a MultiNestedTensor containing all -1's
# In MultiNestedTensor with torch.long dtype,
Expand Down
3 changes: 3 additions & 0 deletions torch_frame/data/multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def device(self) -> torch.device:
def dtype(self) -> torch.dtype:
return self.values.dtype

def is_floating_point(self) -> bool:
return self.values.is_floating_point()

def clone(self) -> _MultiTensor:
return self.__class__(
self.num_rows,
Expand Down

0 comments on commit 63cafb7

Please sign in to comment.