Empty second-order derivative (= Hessian) for the segment_* reductions #299

Open · jeanfeydy opened this issue May 19, 2022 · 3 comments

@jeanfeydy

Hi @rusty1s,
Thanks again for your great work on this library!

I am currently experimenting with second-order derivatives of computations that involve torch_scatter operations, and noticed that the segment_coo and segment_csr operators are not twice differentiable with the "sum" reduction. The following script reproduces the behavior:

import torch
import torch_scatter

# Values:
val = torch.FloatTensor([[0, 1, 2]])
# Groups:
gr_coo = torch.LongTensor([[0, 0, 1]])
gr_csr = torch.LongTensor([[0, 2, 3]])

val.requires_grad = True
B, D = val.shape


def group_reduce(*, values, groups, reduction, output_size, backend):

    if backend == "torch":
        # Compatibility switch: torch.scatter_reduce calls "max" "amax".
        # Note: this relies on the prototype scatter_reduce signature from
        # PyTorch 1.11; later releases changed this API.
        if reduction == "max":
            reduction = "amax"
        return torch.scatter_reduce(
            values, 1, groups, reduction, output_size=output_size
        )
    elif backend == "pyg":
        return torch_scatter.scatter(
            values, groups, dim=1, dim_size=output_size, reduce=reduction
        )
    elif backend == "coo":
        return torch_scatter.segment_coo(
            values, groups, dim_size=output_size, reduce=reduction
        )
    elif backend == "csr":
        return torch_scatter.segment_csr(values, groups, reduce=reduction)
    else:
        raise ValueError(
            f"Invalid value for the scatter backend ({backend}), "
            "should be one of 'torch', 'pyg', 'coo' or 'csr'."
        )


for backend in ["torch", "pyg", "coo", "csr"]:

    red = group_reduce(
        values=val,
        groups=gr_csr if backend == "csr" else gr_coo,
        reduction="sum",
        output_size=2,
        backend=backend,
    )

    # Compute an arbitrary scalar value out of our reduction. The zero-weighted
    # second term keeps val attached to the graph so that the second
    # autograd.grad call below does not fail on an unused input:
    v = (red ** 2).sum(-1) + 0.0 * (val ** 2).sum(-1)
    # Gradient:
    g = torch.autograd.grad(v.sum(), [val], create_graph=True)[0]
    # Hessian:
    h = torch.zeros(B, D, D).type_as(val)
    for d in range(D):
        h[:, d, :] = torch.autograd.grad(g[:, d].sum(), [val], retain_graph=True)[0]

    print(backend, ":")
    print("Value:", v.detach().numpy())
    print("Grad :", g.detach().numpy())
    print("Hessian:")
    print(h.detach().numpy())
    print("--------------")

The output shows that torch_scatter.scatter and torch.scatter_reduce coincide on all derivatives, while the two segment_* implementations return an all-zero second-order derivative:

torch :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
pyg :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
--------------
csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
--------------

Is this expected behavior on your side?
Support for second-order derivatives would be especially useful for e.g. Newton optimization.
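For context, here is a minimal sketch of the kind of Newton update that this would enable. It is purely illustrative: the objective f and the damping constant are assumptions of this example, not anything from torch_scatter:

import torch

# Hypothetical twice-differentiable scalar objective (a stand-in for a loss
# that would involve a segment_* reduction):
def f(x):
    return (x ** 2).sum() + x[0] * x[1]

x = torch.tensor([1.0, 2.0], requires_grad=True)

# Gradient and dense Hessian via double backprop -- exactly the pattern that
# currently returns zeros when f goes through segment_coo or segment_csr:
g = torch.autograd.grad(f(x), x, create_graph=True)[0]
H = torch.stack(
    [torch.autograd.grad(g[i], x, retain_graph=True)[0] for i in range(x.numel())]
)

# One damped Newton step (the 1e-3 damping is an arbitrary choice):
step = torch.linalg.solve(H + 1e-3 * torch.eye(x.numel()), g)
x = (x - step).detach().requires_grad_()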

I'm sure that I could hack something with a torch.autograd.Function wrapper for my own use case, but a proper fix would certainly be useful to other people. Unfortunately, I am not familiar enough with the PyTorch C++ API to fix e.g. segment_csr.cpp myself and write a Pull Request for this :-(

What do you think?

@jeanfeydy (Author)

As far as I can tell, the problem comes from the fact that gather_coo and gather_csr are not properly linked to their backward operators, which are respectively segment_coo and segment_csr with a "sum" reduction. I don't know why this is the case, since the repo definitely seems to contain dedicated C++ code for it.
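
To check this hypothesis, a quick diagnostic (my own snippet, not library code) is to look at whether the first-order gradient itself requires grad:

import torch
import torch_scatter

x = torch.randn(3, requires_grad=True)
ptr = torch.tensor([0, 2, 3])

# The first backward pass through segment_csr(..., reduce="sum") should be a
# gather_csr; if that gather is registered with autograd, the gradient tensor
# stays attached to the graph and a second backward pass can flow through it.
y = torch_scatter.segment_csr(x, ptr, reduce="sum")
(g,) = torch.autograd.grad((y ** 2).sum(), x, create_graph=True)

# Expected True for a twice-differentiable op; the zero Hessians above suggest
# that the dependency on x is lost here instead.
print(g.requires_grad)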

In any case, here is a quick workaround:

import torch
import torch_scatter

# For the CSR operator: segment-sum and gather are adjoint linear maps, so
# each one's backward is the other. Chaining the two autograd.Functions below
# therefore makes them differentiable to arbitrary order.
class SumCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups):
        ctx.save_for_backward(groups)
        return torch_scatter.segment_csr(values, groups, reduce="sum")

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return GatherCSR.apply(grad_output, groups), None


class GatherCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups):
        ctx.save_for_backward(groups)
        return torch_scatter.gather_csr(values, groups)

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return SumCSR.apply(grad_output, groups), None


# Same pattern for the COO operator; dim_size is carried through so that the
# gathered gradient can be re-reduced to an output of the right size.
class SumCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups, dim_size):
        ctx.save_for_backward(groups)
        ctx.dim_size = dim_size
        return torch_scatter.segment_coo(
            values, groups, dim_size=dim_size, reduce="sum"
        )

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return GatherCOO.apply(grad_output, groups, ctx.dim_size), None, None


class GatherCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups, dim_size):
        ctx.save_for_backward(groups)
        ctx.dim_size = dim_size
        return torch_scatter.gather_coo(values, groups)

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return SumCOO.apply(grad_output, groups, ctx.dim_size), None, None

Then, the code below runs just fine:

# Values:
val = torch.FloatTensor([[0, 1, 2]])
# Groups:
gr_coo = torch.LongTensor([[0, 0, 1]])
gr_csr = torch.LongTensor([[0, 2, 3]])

val.requires_grad = True
B, D = val.shape


def group_reduce(*, values, groups, reduction, output_size, backend):

    if backend == "torch":
        # Compatibility switch: torch.scatter_reduce calls "max" "amax".
        if reduction == "max":
            reduction = "amax"
        return torch.scatter_reduce(
            values, 1, groups, reduction, output_size=output_size
        )
    elif backend == "pyg":
        return torch_scatter.scatter(
            values, groups, dim=1, dim_size=output_size, reduce=reduction
        )

    elif backend == "coo":
        return torch_scatter.segment_coo(
            values, groups, dim_size=output_size, reduce=reduction
        )
    elif backend == "my_coo":
        if reduction == "sum":
            return SumCOO.apply(values, groups, output_size)
        else:
            return torch_scatter.segment_coo(
                values, groups, dim_size=output_size, reduce=reduction
            )

    elif backend == "csr":
        return torch_scatter.segment_csr(values, groups, reduce=reduction)
    elif backend == "my_csr":
        if reduction == "sum":
            return SumCSR.apply(values, groups)
        else:
            return torch_scatter.segment_csr(values, groups, reduce=reduction)
    else:
        raise ValueError(
            f"Invalid value for the scatter backend ({backend}), "
            "should be one of 'torch', 'pyg', 'coo' or 'csr'."
        )


for backend in ["torch", "pyg", "coo", "my_coo", "csr", "my_csr"]:

    red = group_reduce(
        values=val,
        groups=gr_csr if "csr" in backend else gr_coo,
        reduction="sum",
        output_size=2,
        backend=backend,
    )

    # Compute an arbitrary scalar value out of our reduction:
    v = (red ** 2).sum(-1) + 0.0 * (val ** 2).sum(-1)
    # Gradient:
    g = torch.autograd.grad(v.sum(), [val], create_graph=True)[0]
    # Hessian:
    h = torch.zeros(B, D, D).type_as(val)
    for d in range(D):
        h[:, d, :] = torch.autograd.grad(g[:, d].sum(), [val], retain_graph=True)[0]

    print(backend, ":")
    print("Value:", v.detach().numpy())
    print("Grad :", g.detach().numpy())
    print("Hessian:")
    print(h.detach().numpy())
    print("--------------")

With the following output (notice that "torch", "pyg", "my_coo" and "my_csr" all coincide with each other):

torch :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
pyg :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
--------------
my_coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
--------------
my_csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------

Best regards,
Jean

@rusty1s (Owner) commented May 21, 2022

Thanks @jeanfeydy for this insightful thread. I will need to take a closer look at this. It looks like backward linkage through the PyTorch C++ API indeed behaves differently from the Python one; I'm not really sure why. Glad that you already found a way to fix this on your end.

@github-actions

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?
