From 783ed839bcb25117ef99c221a60da52d20070e9b Mon Sep 17 00:00:00 2001 From: Aliaksandr Yakutovich Date: Wed, 8 Jan 2020 17:23:12 +0000 Subject: [PATCH] Working version of trajectory visualizer based on bqplot --- aiida_datatypes_viewers.ipynb | 83 ++++++++++++++++++++++++++++ aiidalab_widgets_base/viewers.py | 94 +++++++++++++++++++++++++++++++- 2 files changed, 176 insertions(+), 1 deletion(-) diff --git a/aiida_datatypes_viewers.ipynb b/aiida_datatypes_viewers.ipynb index 3be943a41..cf38f5ac4 100644 --- a/aiida_datatypes_viewers.ipynb +++ b/aiida_datatypes_viewers.ipynb @@ -1,5 +1,16 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "%aiida" + ] + }, { "cell_type": "code", "execution_count": null, @@ -7,10 +18,82 @@ "outputs": [], "source": [ "from os import path\n", + "import numpy\n", "from aiida.plugins import DataFactory\n", "from aiidalab_widgets_base import viewer" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TrajectoryData" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# visualize TrajectoryData\n", + "TrajectoryData = DataFactory('array.trajectory')\n", + "\n", + "traj = TrajectoryData()\n", + "\n", + "stepids = numpy.array([60, 70])\n", + "times = stepids * 0.01\n", + "cells = numpy.array([\n", + " [\n", + " [2., 0., 0.,], \n", + " [0., 2., 0.,],\n", + " [0., 0., 2.,]\n", + " ],\n", + " [\n", + " [3., 0., 0.,],\n", + " [0., 3., 0.,],\n", + " [0., 0., 3.,]\n", + " ]\n", + "])\n", + "symbols = ['H', 'O', 'H']\n", + "positions = numpy.array([\n", + " [\n", + " [0., 0., 0.],\n", + " [0.5, 0.5, 0.5],\n", + " [1.5, 1.5, 1.5]\n", + " ],\n", + " [\n", + " [0., 0., 0.],\n", + " [0.5, 0.5, 0.5],\n", + " [1.5, 1.5, 1.5]],\n", + "])\n", + "velocities = numpy.array([\n", + " [\n", + " [0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]\n", + " ],\n", + " [\n", + " [0.5, 0.5, 0.5],\n", + " [0.5, 0.5, 0.5],\n", + " [-0.5, -0.5, -0.5]\n", + " ]\n", + "])\n", + "\n", + "energy = numpy.array([1., 2.])\n", + "# I set the node\n", + "traj.set_trajectory(\n", + " stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times,\n", + " velocities=velocities\n", + ")\n", + "traj.set_array('energy', energy)\n", + "\n", + "traj.store()\n", + "vwr = viewer(traj)\n", + "#vwr = viewer(load_node('9bfc006b'))\n", + "display(vwr)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/aiidalab_widgets_base/viewers.py b/aiidalab_widgets_base/viewers.py index 3daefddb4..ae01ff35e 100644 --- a/aiidalab_widgets_base/viewers.py +++ b/aiidalab_widgets_base/viewers.py @@ -190,7 +190,6 @@ def _on_atom_click(self, change=None): # pylint:disable=unused-argument """Update selection when clicked on atom.""" if 'atom1' not in self._viewer.picked.keys(): return # did not click on atom - index = self._viewer.picked['atom1']['index'] if index not in self.selection: @@ -310,6 +309,98 @@ def _update_structure_view(self, change): self._viewer.add_unitcell() # pylint: disable=no-member + +class TrajectoryDataViewer(ipw.VBox): + """Viewer class for TrajectoryData object.""" + + def __init__(self, trajectory, downloadable=True, **kwargs): + import bqplot.pyplot as plt + + # TrajectoryData object from AiiDA + self._trajectory = trajectory + self._frames = len(self._trajectory.get_stepids()) + self._structures = [trajectory.get_step_structure(i) for i in range(self._frames)] + + # Plot object. + self._plot = plt.figure(title="plot", layout={"width": "99%"}) + + # Trajectory navigator. + self._step_selector = ipw.IntSlider( + value=0, + min=0, + max=self._frames - 1, + ) + self._step_selector.observe(self.update_selection, names="value") + + # Property to plot. + self._property_selector = ipw.Dropdown( + options=trajectory.get_arraynames(), + value='energy', + description="Value to plot:", + ) + self._property_selector.observe(self.update_all, names="value") + + # Preparing scales. + x_data = self._trajectory.get_stepids() + + self._plot_line = plt.plot(x_data, self._trajectory.get_array(self._property_selector.value)[:len(x_data)]) + self._plot_circle = plt.scatter(x_data, self._trajectory.get_array(self._property_selector.value)[:len(x_data)]) + + on_plot_click_global = self.update_selection_from_plot + + def update_selection(self, change): + on_plot_click_global(self, change) + + self._plot_circle.on_element_click(update_selection) + + self._plot_select_circle = plt.scatter( + [self._trajectory.get_stepids()[self._step_selector.value]], + [self._trajectory.get_array(self._property_selector.value)[self._step_selector.value]], + stroke='red', + ) + + self._plot.axes[1].tick_format = "0.3g" + self.message = ipw.HTML() + + # Structure viewer. + self._struct_viewer = StructureDataViewer(self._structures, + downloadable=downloadable, + configure_view=False, + layout={"width": "50%"}) + + children = [ + ipw.HBox([self._struct_viewer, + ipw.VBox([self._plot, self._property_selector], layout={"width": "50%"})]), + self._step_selector, + self.message, + ] + + super().__init__(children, **kwargs) + + def update_all(self, change=None): # pylint: disable=unused-argument + """Update the data plot.""" + x_data = self._trajectory.get_stepids() + + self._plot_circle.x = x_data + self._plot_circle.y = self._trajectory.get_array(self._property_selector.value)[:len(x_data)] + + self._plot_line.x = x_data + self._plot_line.y = self._trajectory.get_array(self._property_selector.value)[:len(x_data)] + + self.update_selection() + + def update_selection(self, change=None): # pylint: disable=unused-argument + """Update selected point only.""" + self._struct_viewer.frame = self._step_selector.value + self._plot_select_circle.x = [self._trajectory.get_stepids()[self._step_selector.value]] + self._plot_select_circle.y = [ + self._trajectory.get_array(self._property_selector.value)[self._step_selector.value] + ] + + def update_selection_from_plot(self, _, selected_point): + self._step_selector.value = selected_point['data']['index'] + + class FolderDataViewer(ipw.VBox): """Viewer class for FolderData object. @@ -398,4 +489,5 @@ def __init__(self, bands, **kwargs): 'data.cif.CifData.': StructureDataViewer, 'data.folder.FolderData.': FolderDataViewer, 'data.array.bands.BandsData.': BandsDataViewer, + 'data.array.trajectory.TrajectoryData.': TrajectoryDataViewer, }