From 10654e576d531d4801a75750b1d36e1df48fd76a Mon Sep 17 00:00:00 2001 From: Edan Bainglass Date: Mon, 28 Oct 2024 09:02:44 +0000 Subject: [PATCH] Fix `structure.kind`-related bugs --- .../configuration/advanced/hubbard/model.py | 13 ++++++---- .../configuration/advanced/pseudos/model.py | 8 +++--- .../configuration/advanced/pseudos/pseudos.py | 26 ++++++++----------- tests/test_pseudo.py | 6 ++--- 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py b/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py index 997fc8719..847f81883 100644 --- a/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py @@ -102,7 +102,7 @@ def _update_defaults(self, which): ), [ Element(symbol) - for symbol in self.input_structure.get_kind_names() + for symbol in self.input_structure.get_symbols_set() ], ) ] @@ -124,13 +124,16 @@ def _get_default_eigenvalues(self): return deepcopy(self._defaults["eigenvalues"]) def _get_labels(self): - symbols = self.input_structure.get_kind_names() hubbard_manifold_list = [ - self._get_manifold(Element(symbol)) for symbol in symbols + self._get_manifold(Element(kind.symbol)) + for kind in self.input_structure.kinds ] return [ - f"{symbol} - {manifold}" - for symbol, manifold in zip(symbols, hubbard_manifold_list) + f"{kind_name} - {manifold}" + for kind_name, manifold in zip( + self.input_structure.get_kind_names(), + hubbard_manifold_list, + ) ] def _get_manifold(self, element): diff --git a/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py b/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py index 266b3e573..0d407cbcb 100644 --- a/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py @@ -169,14 +169,12 @@ def update_default_cutoffs(self): """ else: - symbols = ( - self.input_structure.get_kind_names() if self.input_structure else [] - ) + kinds = self.input_structure.kinds if self.input_structure else [] ecutwfc_list = [] ecutrho_list = [] - for symbol in symbols: - cutoff = cutoff_dict.get(symbol, {}) + for kind in kinds: + cutoff = cutoff_dict.get(kind.symbol, {}) ecutrho, ecutwfc = ( U.Quantity(v, current_unit).to("Ry").to_tuple()[0] for v in cutoff.values() diff --git a/src/aiidalab_qe/app/configuration/advanced/pseudos/pseudos.py b/src/aiidalab_qe/app/configuration/advanced/pseudos/pseudos.py index 3acc85401..68ca5f645 100644 --- a/src/aiidalab_qe/app/configuration/advanced/pseudos/pseudos.py +++ b/src/aiidalab_qe/app/configuration/advanced/pseudos/pseudos.py @@ -287,22 +287,18 @@ def _build_setter_widgets(self): children = [] - elements = ( - self._model.input_structure.get_kind_names() - if self._model.input_structure - else [] - ) + kinds = self._model.input_structure.kinds if self._model.input_structure else [] - for index, element in enumerate(elements): - upload_widget = PseudoUploadWidget(element=element) + for index, kind in enumerate(kinds): + upload_widget = PseudoUploadWidget(kind_name=kind.name) pseudo_link = ipw.link( (self._model, "dictionary"), (upload_widget, "pseudo"), [ - lambda d, element=element: orm.load_node(d.get(element)), - lambda v, element=element: { + lambda d, symbol=kind.symbol: orm.load_node(d.get(symbol)), + lambda v, symbol=kind.symbol: { **self._model.dictionary, - element: v.uuid, + symbol: v.uuid, }, ], ) @@ -345,13 +341,13 @@ class PseudoUploadWidget(ipw.HBox): cutoffs = tl.List(tl.Float(), []) error_message = tl.Unicode(allow_none=True) - def __init__(self, element, **kwargs): + def __init__(self, kind_name, **kwargs): super().__init__( children=[LoadingWidget("Loading pseudopotential uploader")], **kwargs, ) - self.element = element + self.kind_name = kind_name self.rendered = False @@ -363,7 +359,7 @@ def render(self): description="Upload", multiple=False, ) - self.pseudo_text = ipw.Text(description=self.element) + self.pseudo_text = ipw.Text(description=self.kind_name) self.file_upload.observe(self._on_file_upload, "value") cutoffs_message_template = """ @@ -413,9 +409,9 @@ def _on_file_upload(self, change=None): self.pseudo.store() # check if element is matched with the pseudo - element = "".join([i for i in self.element if not i.isdigit()]) + element = "".join([i for i in self.kind_name if not i.isdigit()]) if element != self.pseudo.element: - self.error_message = f"""
ERROR: Element {self.element} is not matched with the pseudo {self.pseudo.element}
""" + self.error_message = f"""
ERROR: Element {self.kind_name} is not matched with the pseudo {self.pseudo.element}
""" self._reset() else: self.pseudo_text.value = filename diff --git a/tests/test_pseudo.py b/tests/test_pseudo.py index 36c72852b..3c193c6f1 100644 --- a/tests/test_pseudo.py +++ b/tests/test_pseudo.py @@ -228,7 +228,7 @@ def test_pseudo_upload_widget(generate_upf_data): # the widget initialize with the pseudo as input to mock how it will # be used in PseudoSetter when the pseudo family is set. old_pseudo = generate_upf_data("O", "O_old.upf") - w = PseudoUploadWidget(element="O1") + w = PseudoUploadWidget(kind_name="O1") w.pseudo = old_pseudo w.cutoffs = [30, 240] w.render() @@ -236,7 +236,7 @@ def test_pseudo_upload_widget(generate_upf_data): message = "Recommended ecutwfc: {ecutwfc} Ry ecutrho: {ecutrho} Ry" assert w.pseudo.filename == "O_old.upf" - assert w.element == "O1" + assert w.kind_name == "O1" assert message.format(ecutwfc=30.0, ecutrho=240.0) in w.cutoff_message.value assert w.error_message is None @@ -253,7 +253,7 @@ def test_pseudo_upload_widget(generate_upf_data): ) assert w.pseudo.filename == "O_new.upf" - assert w.element == "O1" + assert w.kind_name == "O1" assert w.error_message is None # test upload a invalid pseudo of other element