Skip to content

Commit

Permalink
Working version of trajectory visualizer based on bqplot
Browse files Browse the repository at this point in the history
  • Loading branch information
yakutovicha committed Feb 5, 2020
1 parent 6b2018d commit 783ed83
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
83 changes: 83 additions & 0 deletions aiida_datatypes_viewers.ipynb
Original file line number Diff line number Diff line change
@@ -1,16 +1,99 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%aiida"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from os import path\n",
"import numpy\n",
"from aiida.plugins import DataFactory\n",
"from aiidalab_widgets_base import viewer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TrajectoryData"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# visualize TrajectoryData\n",
"TrajectoryData = DataFactory('array.trajectory')\n",
"\n",
"traj = TrajectoryData()\n",
"\n",
"stepids = numpy.array([60, 70])\n",
"times = stepids * 0.01\n",
"cells = numpy.array([\n",
" [\n",
" [2., 0., 0.,], \n",
" [0., 2., 0.,],\n",
" [0., 0., 2.,]\n",
" ],\n",
" [\n",
" [3., 0., 0.,],\n",
" [0., 3., 0.,],\n",
" [0., 0., 3.,]\n",
" ]\n",
"])\n",
"symbols = ['H', 'O', 'H']\n",
"positions = numpy.array([\n",
" [\n",
" [0., 0., 0.],\n",
" [0.5, 0.5, 0.5],\n",
" [1.5, 1.5, 1.5]\n",
" ],\n",
" [\n",
" [0., 0., 0.],\n",
" [0.5, 0.5, 0.5],\n",
" [1.5, 1.5, 1.5]],\n",
"])\n",
"velocities = numpy.array([\n",
" [\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.]\n",
" ],\n",
" [\n",
" [0.5, 0.5, 0.5],\n",
" [0.5, 0.5, 0.5],\n",
" [-0.5, -0.5, -0.5]\n",
" ]\n",
"])\n",
"\n",
"energy = numpy.array([1., 2.])\n",
"# I set the node\n",
"traj.set_trajectory(\n",
" stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times,\n",
" velocities=velocities\n",
")\n",
"traj.set_array('energy', energy)\n",
"\n",
"traj.store()\n",
"vwr = viewer(traj)\n",
"#vwr = viewer(load_node('9bfc006b'))\n",
"display(vwr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
94 changes: 93 additions & 1 deletion aiidalab_widgets_base/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def _on_atom_click(self, change=None): # pylint:disable=unused-argument
"""Update selection when clicked on atom."""
if 'atom1' not in self._viewer.picked.keys():
return # did not click on atom

index = self._viewer.picked['atom1']['index']

if index not in self.selection:
Expand Down Expand Up @@ -310,6 +309,98 @@ def _update_structure_view(self, change):
self._viewer.add_unitcell() # pylint: disable=no-member



class TrajectoryDataViewer(ipw.VBox):
"""Viewer class for TrajectoryData object."""

def __init__(self, trajectory, downloadable=True, **kwargs):
import bqplot.pyplot as plt

# TrajectoryData object from AiiDA
self._trajectory = trajectory
self._frames = len(self._trajectory.get_stepids())
self._structures = [trajectory.get_step_structure(i) for i in range(self._frames)]

# Plot object.
self._plot = plt.figure(title="plot", layout={"width": "99%"})

# Trajectory navigator.
self._step_selector = ipw.IntSlider(
value=0,
min=0,
max=self._frames - 1,
)
self._step_selector.observe(self.update_selection, names="value")

# Property to plot.
self._property_selector = ipw.Dropdown(
options=trajectory.get_arraynames(),
value='energy',
description="Value to plot:",
)
self._property_selector.observe(self.update_all, names="value")

# Preparing scales.
x_data = self._trajectory.get_stepids()

self._plot_line = plt.plot(x_data, self._trajectory.get_array(self._property_selector.value)[:len(x_data)])
self._plot_circle = plt.scatter(x_data, self._trajectory.get_array(self._property_selector.value)[:len(x_data)])

on_plot_click_global = self.update_selection_from_plot

def update_selection(self, change):
on_plot_click_global(self, change)

self._plot_circle.on_element_click(update_selection)

self._plot_select_circle = plt.scatter(
[self._trajectory.get_stepids()[self._step_selector.value]],
[self._trajectory.get_array(self._property_selector.value)[self._step_selector.value]],
stroke='red',
)

self._plot.axes[1].tick_format = "0.3g"
self.message = ipw.HTML()

# Structure viewer.
self._struct_viewer = StructureDataViewer(self._structures,
downloadable=downloadable,
configure_view=False,
layout={"width": "50%"})

children = [
ipw.HBox([self._struct_viewer,
ipw.VBox([self._plot, self._property_selector], layout={"width": "50%"})]),
self._step_selector,
self.message,
]

super().__init__(children, **kwargs)

def update_all(self, change=None): # pylint: disable=unused-argument
"""Update the data plot."""
x_data = self._trajectory.get_stepids()

self._plot_circle.x = x_data
self._plot_circle.y = self._trajectory.get_array(self._property_selector.value)[:len(x_data)]

self._plot_line.x = x_data
self._plot_line.y = self._trajectory.get_array(self._property_selector.value)[:len(x_data)]

self.update_selection()

def update_selection(self, change=None): # pylint: disable=unused-argument
"""Update selected point only."""
self._struct_viewer.frame = self._step_selector.value
self._plot_select_circle.x = [self._trajectory.get_stepids()[self._step_selector.value]]
self._plot_select_circle.y = [
self._trajectory.get_array(self._property_selector.value)[self._step_selector.value]
]

def update_selection_from_plot(self, _, selected_point):
self._step_selector.value = selected_point['data']['index']


class FolderDataViewer(ipw.VBox):
"""Viewer class for FolderData object.
Expand Down Expand Up @@ -398,4 +489,5 @@ def __init__(self, bands, **kwargs):
'data.cif.CifData.': StructureDataViewer,
'data.folder.FolderData.': FolderDataViewer,
'data.array.bands.BandsData.': BandsDataViewer,
'data.array.trajectory.TrajectoryData.': TrajectoryDataViewer,
}

0 comments on commit 783ed83

Please sign in to comment.