As I was refactoring the PS code, I spent a bit of time trying to make `plot_representation` look more informative. This is not part of that PR, so I'm pasting the functionality here in case I return to it. The following needs to be double-checked to make sure it's accurate and that the ordering is correct; the titles also need cleaning up and helpful labels adding (and I think `update_plot` would need to be updated as well):
```python
from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from torch import Tensor

# clean_up_axes, clean_stem_plot and to_numpy are plenoptic helpers, assumed to be
# imported as in the rest of the Portilla-Simoncelli module.


def plot_representation(
    self,
    data: Union[Tensor, OrderedDict],
    ax: Optional[mpl.axes.Axes] = None,
    figsize: Tuple[float, float] = (15, 15),
    ylim: Optional[Tuple[float, float]] = None,
    batch_idx: int = 0,
    title: Optional[str] = None,
) -> Tuple[mpl.figure.Figure, List[mpl.axes.Axes]]:
    r"""Plot the representation in a human viewable format: stem plots
    with data separated out by statistic type.

    Currently, this averages over all channels in the representation.

    Parameters
    ----------
    data :
        The data to show on the plot. Should look like the output of
        ``self.forward(img)``, with the exact same structure (e.g., as
        returned by ``metamer.representation_error()`` or another instance
        of this class).
    ax :
        Axes where we will plot the data. If an ``mpl.axes.Axes``, will
        subdivide it into an ``(self.n_scales + 2) x 5`` grid of new axes.
        If None, we create a new figure.
    figsize :
        The size of the figure. Ignored if ax is not None.
    ylim :
        The ylimits of the plot.
    batch_idx :
        Which index to take from the batch dimension (the first one).
    title :
        Title for the plot.

    Returns
    -------
    fig :
        Figure containing the plot.
    axes :
        List of axes containing the plots.
    """
    n_rows = self.n_scales + 2
    n_cols = 5

    # pick the batch_idx we want, and average over channels.
    rep = self.convert_to_dict(data[batch_idx].mean(0))
    data = self._representation_for_plotting(rep)

    # Set up grid spec
    if ax is None:
        # we add 2 to order because we're adding one to get the
        # number of orientations and then another one to add an
        # extra column for the mean luminance plot
        fig = plt.figure(figsize=figsize)
        gs = mpl.gridspec.GridSpec(n_rows, n_cols, fig)
    else:
        # warnings.warn("ax is not None, so we're ignoring figsize...")
        # want to make sure the axis we're taking over is basically invisible.
        ax = clean_up_axes(
            ax, False, ["top", "right", "bottom", "left"], ["x", "y"]
        )
        gs = ax.get_subplotspec().subgridspec(n_rows, n_cols)
        fig = ax.figure

    # mapping between data and column ("pixels+var_highpass" gets -1, but it
    # is placed separately below)
    cols = {k: i - 1 for i, k in enumerate(data.keys())}

    # plot data
    axes = []
    for k, v in data.items():
        # handle this one differently than the rest, because it has no
        # multi-scale component
        if k == "pixels+var_highpass":
            ax = fig.add_subplot(gs[0, 0])
            ax = clean_stem_plot(to_numpy(v), ax, k, ylim=ylim)
            axes.append(ax)
        else:
            # scales is along the last dimension, and that's what we want
            # to iterate through; scale 0 goes in the bottom row
            for sc in range(v.shape[-1]):
                ax = fig.add_subplot(gs[n_rows - sc - 1, cols[k]])
                ax = clean_stem_plot(to_numpy(v[..., sc]).flatten(), ax, k, ylim=ylim)
                axes.append(ax)

    if title is not None:
        fig.suptitle(title)

    return fig, axes


def _representation_for_plotting(self, rep: OrderedDict) -> OrderedDict:
    r"""Convert the data into a dictionary representation that is more
    convenient for plotting.

    Intended as a helper function for plot_representation.
    """
    if rep["skew_reconstructed"].ndim > 1:
        raise ValueError(
            "Currently, only know how to plot single batch and channel at a time! "
            "Select and/or average over those dimensions"
        )
    data = OrderedDict()
    data["pixels+var_highpass"] = torch.stack(
        list(rep.pop("pixel_statistics").values()) + [rep.pop("var_highpass_residual")]
    )
    # stacking along dim 0: want scales to be on the last axis
    data["reconstructed_image_stats"] = torch.stack(
        (
            rep.pop("std_reconstructed").pow(2),
            rep.pop("skew_reconstructed"),
            rep.pop("kurtosis_reconstructed"),
            rep.pop("auto_correlation_reconstructed").norm(p=2, dim=(0, 1)),
        ),
        0,
    )
    # norm over the first two dims, then transpose: want scales to be on the last axis
    data["auto_correlation_magnitude"] = torch.norm(
        rep.pop("auto_correlation_magnitude"), p=2, dim=(0, 1)
    ).t()
    # add the rest of the keys
    data.update(rep)
    return data
```
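Since the ordering still needs double-checking, here is a minimal, plenoptic-free sketch that mirrors only the row/column arithmetic of `plot_representation` (the key at position `i` goes to column `i - 1`, scale `sc` goes to row `n_rows - sc - 1`, and `"pixels+var_highpass"` always takes `gs[0, 0]`) and flags panels that fall outside the grid or land on the same cell. `grid_layout` and `check_layout` are hypothetical helpers, not part of plenoptic, and the per-key scale counts below are made-up placeholders; the real values are the `v.shape[-1]` of each entry returned by `_representation_for_plotting`.

```python
from typing import Dict, List, Tuple

Layout = Dict[Tuple[str, int], Tuple[int, int]]


def grid_layout(n_scales_per_key: Dict[str, int], n_rows: int) -> Layout:
    """Mirror the (row, col) placement arithmetic of plot_representation.

    n_scales_per_key maps each plotted key to the size of its last (scale)
    dimension, in the same order as the dict built by
    _representation_for_plotting. The value for "pixels+var_highpass" is ignored.
    """
    # same mapping as in plot_representation: key at position i -> column i - 1
    cols = {k: i - 1 for i, k in enumerate(n_scales_per_key)}
    layout: Layout = {}
    for k, n_sc in n_scales_per_key.items():
        if k == "pixels+var_highpass":
            # no multi-scale component: always the top-left cell
            layout[(k, 0)] = (0, 0)
        else:
            for sc in range(n_sc):
                # scale 0 lands in the bottom row, higher scales move up
                layout[(k, sc)] = (n_rows - sc - 1, cols[k])
    return layout


def check_layout(layout: Layout, n_rows: int, n_cols: int) -> List[str]:
    """List panels that fall outside the grid or share a cell with another panel."""
    problems: List[str] = []
    seen: Dict[Tuple[int, int], str] = {}
    for (k, sc), (r, c) in layout.items():
        if not (0 <= r < n_rows and 0 <= c < n_cols):
            problems.append(f"{k} scale {sc} -> ({r}, {c}) is outside the grid")
        elif (r, c) in seen:
            problems.append(f"{k} scale {sc} overlaps {seen[(r, c)]} at ({r}, {c})")
        else:
            seen[(r, c)] = f"{k} scale {sc}"
    return problems


# hypothetical sizes: n_scales = 4, so n_rows = n_scales + 2 = 6
n_rows, n_cols = 6, 5
layout = grid_layout(
    {"pixels+var_highpass": 0, "reconstructed_image_stats": 5, "auto_correlation_magnitude": 4},
    n_rows,
)
for (key, sc), cell in layout.items():
    print(f"{key:28s} scale {sc}: gs{cell}")
print(check_layout(layout, n_rows, n_cols))
```

With a 5-column grid, the seventh key in the dict (and any after it) gets a column index of 5 or more, and a key in column 0 with `n_rows` scales puts its top scale on `gs[0, 0]`, on top of the `pixels+var_highpass` panel; both cases show up in `check_layout`'s output.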