
[Feature] Support for a scatter 'concatenate' or 'groupby' operation #398

Open
davidbuterez opened this issue Oct 11, 2023 · 10 comments

@davidbuterez

Hi, thanks for the amazing work so far!

I was wondering if it would be possible to efficiently support a scatter operation that, instead of reducing (e.g., with sum, mean, max, or min), simply returns the values grouped by the index.

For example, following the homepage illustration of this repo:

index = [0, 0, 1, 0, 2, 2, 3, 3]
input = [5, 1, 7, 2, 3, 2, 1, 3]

I would like to get an output similar to this:

0: [5, 1, 2]
1: [7]
2: [3, 2]
3: [1, 3]

(the order within each list would not matter)

I am not sure if I am missing something or if this is already possible using existing operations. Perhaps the varying group lengths are problematic, but they could be handled with nested tensors or padding. I would like to apply this operation several times per training epoch, so ideally it would be efficient on GPUs.
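For what it's worth, a minimal sketch of the requested grouping in plain PyTorch, returning variable-length groups via a stable sort plus split (names are illustrative):

```python
import torch

# Sketch: group scatter inputs by index without reducing.
# A stable argsort keeps the original order within each group.
index = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])
src = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])

order = torch.argsort(index, stable=True)
counts = torch.bincount(index)                 # group sizes: tensor([3, 1, 2, 2])
groups = src[order].split(counts.tolist())     # tuple of variable-length tensors
# → (tensor([5, 1, 2]), tensor([7]), tensor([3, 2]), tensor([1, 3]))
```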

@rusty1s
Owner

rusty1s commented Oct 11, 2023

This corresponds to the CSR representation of a sparse matrix, which is supported in PyG:

import torch
from torch_geometric.utils import to_torch_csr_tensor

index = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])
input = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])

edge_index = torch.stack([index, input], dim=0)

sparse_mat = to_torch_csr_tensor(edge_index)
print(sparse_mat.crow_indices())
print(sparse_mat.col_indices())
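For readers without PyG installed, the same idea can be sketched in plain PyTorch by treating each (index, value) pair as the (row, col) coordinate of a sparse matrix; the matrix size below is an assumption chosen to fit the example data:

```python
import torch

index = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])
values = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])

# Each (index, value) pair becomes a nonzero coordinate; note that
# duplicate (row, col) pairs would be merged by coalesce().
coo = torch.sparse_coo_tensor(
    torch.stack([index, values]),      # 2 x nnz coordinates
    torch.ones(index.numel()),         # dummy entries
    size=(4, 8),                       # assumed: 4 groups, values < 8
).coalesce()
csr = coo.to_sparse_csr()

print(csr.crow_indices())  # group boundaries
print(csr.col_indices())   # grouped values, sorted within each group
```

The values of group i then sit in col_indices() between offsets crow_indices()[i] and crow_indices()[i + 1]; within a group they come out sorted, which is fine here since the order within each list does not matter.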


github-actions bot commented Apr 9, 2024

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

@github-actions github-actions bot added the stale label Apr 9, 2024
@JosephDenman

@rusty1s Is there a simple way to retrieve the outputs that @davidbuterez provided:

0: [5, 1, 2]
1: [7]
2: [3, 2]
3: [1, 3]

from the result of to_torch_csr_tensor(edge_index)? This concatenation aggregation is something I'm working on currently, but I can't actually find an implementation online.

@github-actions github-actions bot removed the stale label Apr 21, 2024
@rusty1s
Owner

rusty1s commented Apr 21, 2024

# rowptr and col are the crow_indices() and col_indices() of the CSR tensor
count = rowptr.diff()               # number of entries per row
groups = col.split(count.tolist())  # tuple of per-row tensors
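Spelled out with concrete values (rowptr and col standing for the crow_indices() and col_indices() of the earlier example's CSR tensor):

```python
import torch

rowptr = torch.tensor([0, 3, 4, 6, 8])           # CSR crow_indices()
col = torch.tensor([1, 2, 5, 7, 2, 3, 1, 3])     # CSR col_indices()

count = rowptr.diff()               # entries per row: tensor([3, 1, 2, 2])
groups = col.split(count.tolist())  # per-row tensors of the grouped values
# → (tensor([1, 2, 5]), tensor([7]), tensor([2, 3]), tensor([1, 3]))
```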

@JosephDenman

I have a PR with a draft implementation if you give me push access to the repo. I'll push to a feature branch I guess? Unsure what the contribution guidelines are.

@JosephDenman

@rusty1s What's the correct way to contribute?

@rusty1s
Owner

rusty1s commented Apr 25, 2024

What do you want to contribute exactly? It looks like groupby is already well supported.

@JosephDenman

JosephDenman commented Apr 25, 2024

Hmm. Unsure what you mean. It's an implementation of Aggregation that just concatenates the features of neighboring nodes using this function.

def concat_group_by(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Pad every group with zeros up to the size of the largest group,
    # then reshape into (num_groups, max_group_size, *feature_dims).
    index_count = torch.bincount(index)            # entries per group
    fill_count = index_count.max() - index_count   # padding needed per group
    fill_zeros = torch.zeros_like(x[0]).repeat(int(fill_count.sum()), *([1] * (x.dim() - 1)))
    fill_index = torch.arange(fill_count.numel()).repeat_interleave(fill_count)
    index_ = torch.cat([index, fill_index], dim=0)
    x_ = torch.cat([x, fill_zeros], dim=0)
    # Stable sort keeps the real entries ahead of the padding in each group.
    x_ = x_[torch.argsort(index_, stable=True)]
    return x_.view(index_count.numel(), int(index_count.max()), *x.shape[1:])

Inputs:

t = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80], [9, 90], [10, 100], [11, 110]])
index = torch.tensor([0, 0, 2, 2, 2, 1, 1])

Output:

tensor([[[  5,  50],
         [  6,  60],
         [  0,   0]],

        [[ 10, 100],
         [ 11, 110],
         [  0,   0]],

        [[  7,  70],
         [  8,  80],
         [  9,  90]]])

AFAIK groupby doesn't exist in PyG, and implementations elsewhere only support groupby using arithmetic operations, sum, mean, etc. This is useful when you want to gather all neighbor node features into the feature of the center node. You can, e.g., do softmax over the result to obtain probabilities of selecting neighbors.
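One caveat with the softmax use case: zero padding is not neutral under softmax (exp(0) = 1), so the padded slots need masking first. A hedged sketch with illustrative scores and mask:

```python
import torch

# Illustrative scores for two center nodes, padded to 3 neighbor slots,
# with a mask marking which slots hold real neighbors.
scores = torch.tensor([[2.0, 1.0, 0.0],
                       [0.5, 0.0, 0.0]])
mask = torch.tensor([[True, True, False],
                     [True, False, False]])

# Masked softmax: padded slots get -inf, hence exactly zero probability.
probs = torch.softmax(scores.masked_fill(~mask, float('-inf')), dim=-1)
```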

@rusty1s
Owner

rusty1s commented May 7, 2024

Wouldn't this be similar to MLPAggregation in PyG?

@wzm2256

wzm2256 commented Sep 27, 2024

For anyone who reads this: I found that the closest thing to a scatter(src, index, dim, reduce='group') operation is the to_dense_batch function. Here is a minimal use case with the data provided by @davidbuterez:

from torch_geometric.utils import to_dense_batch
import torch
batch = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])
x = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])
sorted_batch, indices = torch.sort(batch, 0)
sorted_x = x[indices]

to_dense_batch(sorted_x, sorted_batch)

output:
(tensor([[5, 1, 2],
        [7, 0, 0],
        [3, 2, 0],
        [1, 3, 0]]), tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True, False]]))

However, batch MUST be sorted for this function; more details can be found in the documentation.
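A plain-PyTorch sketch of the same padded output and mask that also handles an unsorted batch via a stable argsort (no PyG required; this mirrors to_dense_batch for the example above):

```python
import torch

batch = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])   # unsorted is fine here
x = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])

counts = torch.bincount(batch)                   # tensor([3, 1, 2, 2])
order = torch.argsort(batch, stable=True)        # stable: keeps in-group order
pos = torch.cat([torch.arange(int(c)) for c in counts])  # slot within group

out = torch.zeros(counts.numel(), int(counts.max()), dtype=x.dtype)
mask = torch.zeros(counts.numel(), int(counts.max()), dtype=torch.bool)
out[batch[order], pos] = x[order]
mask[batch[order], pos] = True
# out:  [[5, 1, 2], [7, 0, 0], [3, 2, 0], [1, 3, 0]]
```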

In addition, I tried to_torch_csr_tensor, but crow_indices and col_indices did not output anything resembling the required output for me. Another reply suggests rowptr.diff(), but it was not clear what that function referred to. Thanks for pointing out the MLPAggregation layer; I found the desired to_dense_batch in the source code of that layer.

@rusty1s I suggest incorporating this functionality into the scatter function in this library for completeness.
