
Edge explanations using dgl.nn.pytorch.explain.GNNExplainer #7507

Open
ayushnoori opened this issue Jul 6, 2024 · 1 comment

ayushnoori commented Jul 6, 2024

Hi DGL team, I’m kindly following up on my Slack messages.

I’m attempting to use dgl.nn.pytorch.explain.GNNExplainer to provide edge-level explanations for a heterogeneous graph transformer built with dgl.nn.pytorch.conv.HGTConv layers. According to the documentation, “the required arguments of its forward function are graph, feat, and eweight (taken optionally). The feat argument is for input node features.”

First, I've modified the HGTConv forward function to take the eweight argument as follows. Could you please advise whether this is correct?

Updated HGTConv code
import math
import types

import torch

from dgl import function as fn
from dgl.nn.pytorch import TypedLinear
from dgl.nn.pytorch import edge_softmax

def forward_exp(self, g, x, ntype, etype, *, presorted=False, eweight=None):
    """Forward computation.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    x : torch.Tensor
        A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
    ntype : torch.Tensor
        An 1D integer tensor of node types. Shape: :math:`(|V|,)`.
    etype : torch.Tensor
        An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
    presorted : bool, optional
        Whether *both* the nodes and the edges of the input graph have been sorted by
        their types. Forward on pre-sorted graph may be faster. Graphs created by
        :func:`~dgl.to_homogeneous` automatically satisfy the condition.
        Also see :func:`~dgl.reorder_graph` for manually reordering the nodes and edges.
    eweight : torch.Tensor, optional
        A 1D tensor of edge weights used to scale the messages, as expected by
        :class:`~dgl.nn.pytorch.explain.GNNExplainer`. Shape: :math:`(|E|,)`.

    Returns
    -------
    torch.Tensor
        New node features. Shape: :math:`(|V|, D_{head} * N_{head})`.
    """
    self.presorted = presorted
    if g.is_block:
        x_src = x
        x_dst = x[: g.num_dst_nodes()]
        srcntype = ntype
        dstntype = ntype[: g.num_dst_nodes()]
    else:
        x_src = x
        x_dst = x
        srcntype = ntype
        dstntype = ntype
    with g.local_scope():
        k = self.linear_k(x_src, srcntype, presorted).view(
            -1, self.num_heads, self.head_size
        )
        q = self.linear_q(x_dst, dstntype, presorted).view(
            -1, self.num_heads, self.head_size
        )
        v = self.linear_v(x_src, srcntype, presorted).view(
            -1, self.num_heads, self.head_size
        )
        g.srcdata["k"] = k
        g.dstdata["q"] = q
        g.srcdata["v"] = v
        g.edata["etype"] = etype
        g.apply_edges(self.message)
        g.edata["m"] = g.edata["m"] * edge_softmax(
            g, g.edata["a"]
        ).unsqueeze(-1)
        
        # Update for GNNExplainer
        if eweight is not None:
            # Multiply messages by edge weights
            eweight = eweight.view(g.edata['m'].shape[0], 1, 1)
            g.edata['m'] = g.edata['m'] * eweight
        g.update_all(fn.copy_e("m", "m"), fn.sum('m', 'h'))
        
        h = g.dstdata["h"].view(-1, self.num_heads * self.head_size)
        # target-specific aggregation
        h = self.drop(self.linear_a(h, dstntype, presorted))
        alpha = torch.sigmoid(self.skip[dstntype]).unsqueeze(-1)
        if x_dst.shape != h.shape:
            h = h * alpha + (x_dst @ self.residual_w) * (1 - alpha)
        else:
            h = h * alpha + x_dst * (1 - alpha)
        if self.use_norm:
            h = self.norm(h)
        return h

I then update the layers in my model with, for example:

# Replace the forward method
model.conv1.forward = types.MethodType(forward_exp, model.conv1)
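As a sanity check on this monkey-patching pattern, here is a minimal stdlib-only sketch (with a hypothetical toy class, not DGL code) showing that types.MethodType binds the replacement function to a single instance, leaving the class and all other instances untouched:

```python
import types

class Greeter:
    def greet(self):
        return "hello"

def greet_loud(self):
    # Patched version; `self` is bound to the instance via MethodType.
    return "HELLO"

a, b = Greeter(), Greeter()
a.greet = types.MethodType(greet_loud, a)  # patch only instance `a`
print(a.greet(), b.greet())  # instance `b` still uses the class method
```

This is why patching `model.conv1.forward` directly on the module instance does not affect other HGTConv layers in the model.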

Critically, it seems that the current implementation of GNNExplainer is limited to node and graph explanations via the explain_node() and explain_graph() functions, respectively, but this is not a limitation in the original paper. What I would need is a function like:

explain_edge(edge_id, graph, feat, **kwargs)

which also takes an edge_id argument.
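In case it helps the discussion, here is a dependency-free sketch of the optimization core I believe an explain_edge() would share with explain_node(): learn sigmoid-gated edge-mask logits by gradient descent so that the masked messages preserve the model's prediction for the target edge, with an L1 penalty for sparsity. Everything below (the toy per-edge contributions, alpha, lr, the zero initialization) is hypothetical and only illustrates the mask-learning loop, not the DGL API:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical per-edge contributions to the model's prediction for the
# target edge; a linear toy model stands in for the GNN forward pass.
contrib = [3.0, 0.1, 0.0, 2.0]
logits = [0.0] * len(contrib)  # edge-mask logits to learn

alpha, lr = 0.05, 0.5  # L1 sparsity weight and step size (hypothetical)
for _ in range(200):
    for i, c in enumerate(contrib):
        m = sigmoid(logits[i])
        # loss = -(masked prediction) + alpha * sum(mask);
        # analytic gradient, chained through the sigmoid gate
        grad = (-c + alpha) * m * (1.0 - m)
        logits[i] -= lr * grad

mask = [sigmoid(z) for z in logits]
top = sorted(range(len(mask)), key=lambda i: -mask[i])[:2]
print(top)  # indices of the two edges the mask deems most important
```

The mask saturates toward 1 for edges that actually drive the prediction (indices 0 and 3 here) and is pushed toward 0 for the rest by the sparsity term, which is the behavior I would want explain_edge() to expose per target edge.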

Could you please advise whether it is possible to use the current implementation of GNNExplainer in DGL to provide edge explanations? If so, I would appreciate your guidance on how to implement this method (is it already on the roadmap? should I start from the source code for explain_node()?); if not, please let me know if there are other explainability methods implemented in DGL that you would recommend for this task.

Thank you!

cc: @marinkaz; from the Slack conversation: @frozenbugs @jermainewang and team


github-actions bot commented Aug 6, 2024

This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you.
