Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Portilla-Simoncelli representation plot #223

Open
billbrod opened this issue Sep 21, 2023 · 0 comments
Open

Improve Portilla-Simoncelli representation plot #223

billbrod opened this issue Sep 21, 2023 · 0 comments
Labels
enhancement New feature or request

Comments

@billbrod
Copy link
Collaborator

As I was refactoring the PS code, I spent a bit of time trying to make plot_representation more informative. This is not part of that PR, so I'm pasting the functionality here in case I return to it. The following still needs to be double-checked to make sure it's accurate and that the ordering is correct. The titles also need cleaning up and helpful labels need to be added (and I think update_plot would need to be updated as well):

    def plot_representation(
            self,
            data: Union[Tensor, OrderedDict],
            ax: Optional[mpl.axes.Axes] = None,
            figsize: Tuple[float, float] = (15, 15),
            ylim: Optional[Tuple[float, float]] = None,
            batch_idx: int = 0,
            title: Optional[str] = None,
    ) -> Tuple[mpl.figure.Figure, List[mpl.axes.Axes]]:
        r"""Plot the representation in a human viewable format -- stem
        plots with data separated out by statistic type.

        Currently, this averages over all channels in the representation.

        The plot is laid out on a grid with ``self.n_scales + 2`` rows and 5
        columns. The combined pixel / highpass-residual statistics occupy the
        top-left cell. Every other statistic gets its own column, in the key
        order produced by ``self._representation_for_plotting``. Each such
        column holds one stem plot per scale, with scale 0 in the bottom row.

        Parameters
        ----------
        data :
            The data to show on the plot. Should look like the output of
            ``self.forward(img)``, with the exact same structure (e.g., as
            returned by ``metamer.representation_error()`` or another instance
            of this class). It is indexed with ``data[batch_idx]`` and then
            averaged over its next dimension (channels).
        ax :
            Axes where we will plot the data. If an ``mpl.axes.Axes``, its
            spines and ticks are hidden and its area is subdivided into a
            ``(self.n_scales + 2) x 5`` grid of new axes. If None, we create a
            new figure.
        figsize :
            The size of the figure. Ignored if ax is not None.
        ylim :
            The ylimits of the plot.
        batch_idx :
            Which index to take from the batch dimension (the first one).
        title :
            Title for the plot.

        Returns
        -------
        fig:
            Figure containing the plot.
        axes:
            List of every axis created, one per stem plot, in plotting order.

        """
        n_rows = self.n_scales + 2
        n_cols = 5

        # pick the batch_idx we want, and average over channels.
        # NOTE(review): this indexing assumes ``data`` is a tensor; an
        # OrderedDict (allowed by the type hint) cannot be indexed this way.
        rep = self.convert_to_dict(data[batch_idx].mean(0))
        data = self._representation_for_plotting(rep)

        # Set up grid spec
        if ax is None:
            fig = plt.figure(figsize=figsize)
            gs = mpl.gridspec.GridSpec(n_rows, n_cols, fig)
        else:
            # want to make sure the axis we're taking over is basically invisible.
            ax = clean_up_axes(
                ax, False, ["top", "right", "bottom", "left"], ["x", "y"]
            )
            gs = ax.get_subplotspec().subgridspec(n_rows, n_cols)
            fig = ax.figure

        # mapping between data and column. The first key
        # ('pixels+var_highpass') is drawn separately in gs[0, 0], so the
        # remaining keys are shifted down by one to start at column 0.
        cols = {k: i - 1 for i, k in enumerate(data.keys())}

        # plot data
        axes = []
        for k, v in data.items():
            if k == 'pixels+var_highpass':
                # handle this one differently than the rest, because it has
                # no multi-scale component
                ax = fig.add_subplot(gs[0, 0])
                ax = clean_stem_plot(to_numpy(v), ax, k, ylim=ylim)
                axes.append(ax)
                continue
            # scales is along the last dimension, and that's what we want
            # to iterate through; scale 0 goes in the bottom row.
            for sc in range(v.shape[-1]):
                ax = fig.add_subplot(gs[n_rows - sc - 1, cols[k]])
                ax = clean_stem_plot(to_numpy(v[..., sc]).flatten(), ax, k,
                                     ylim=ylim)
                # collect every per-scale axis, not just the last one
                axes.append(ax)

        if title is not None:
            fig.suptitle(title)

        return fig, axes

    def _representation_for_plotting(self, rep: OrderedDict) -> OrderedDict:
        r"""Converts the data into a dictionary representation that is more convenient for plotting.

        Intended as a helper function for plot_representation.

        """
        if rep['skew_reconstructed'].ndim > 1:
            raise ValueError("Currently, only know how to plot single batch and channel at a time! "
                             "Select and/or average over those dimensions")
        data = OrderedDict()
        data["pixels+var_highpass"] = torch.stack(list(rep.pop("pixel_statistics").values()) +
                                                  [rep.pop("var_highpass_residual")])
        data["reconstructed_image_stats"] = torch.stack(
            (
                rep.pop("std_reconstructed").pow(2),
                rep.pop("skew_reconstructed"),
                rep.pop("kurtosis_reconstructed"),
                rep.pop("auto_correlation_reconstructed").norm(p=2, dim=(0, 1))
            ),
            # want scales to be on the last axis
            0,
        )

        # want scales to be on the last axis
        data["auto_correlation_magnitude"] = torch.norm(rep.pop("auto_correlation_magnitude"),
                                                        p=2, dim=(0, 1)).t()
        # add the rest of the keys
        data.update(rep)

        return data
@billbrod billbrod added the enhancement New feature or request label Sep 21, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

1 participant