From f1796813cd35ffa1ca3ce5644f08ceb4714bfdc0 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Sat, 21 Sep 2024 11:06:25 +0100 Subject: [PATCH] linting --- ebcc/backend/__init__.py | 4 ++-- ebcc/backend/_ctf.py | 32 ++++++++++++++++---------------- ebcc/backend/_jax.py | 3 ++- ebcc/backend/_numpy.py | 4 ++-- ebcc/backend/_tensorflow.py | 10 +++++----- ebcc/cc/base.py | 8 +++++--- ebcc/cc/gebcc.py | 6 ++++-- ebcc/cc/uebcc.py | 4 +++- ebcc/core/damping.py | 12 +++++++----- ebcc/eom/base.py | 2 +- ebcc/eom/geom.py | 8 ++++---- ebcc/eom/reom.py | 4 ++-- ebcc/eom/ueom.py | 8 ++++---- ebcc/ham/cderis.py | 4 +++- ebcc/ham/elbos.py | 2 +- ebcc/ham/space.py | 10 ++++++---- ebcc/opt/base.py | 2 +- ebcc/opt/gbrueckner.py | 3 +-- ebcc/opt/rbrueckner.py | 3 +-- ebcc/opt/ubrueckner.py | 2 +- ebcc/util/__init__.py | 10 +++++++++- ebcc/util/misc.py | 11 ++++++++--- ebcc/util/permutations.py | 17 +++++++++++------ pyproject.toml | 4 ++++ 24 files changed, 103 insertions(+), 70 deletions(-) diff --git a/ebcc/backend/__init__.py b/ebcc/backend/__init__.py index 775462bb..be92a787 100644 --- a/ebcc/backend/__init__.py +++ b/ebcc/backend/__init__.py @@ -94,9 +94,9 @@ def _put( # TODO MPI has to be manually managed here if isinstance(indices, tuple): indices_flat = np.ravel_multi_index(indices, array.shape) - array.write(indices_flat, values) + array.write(indices_flat, values) # type: ignore else: - array.write(indices, values) + array.write(indices, values) # type: ignore return array else: raise NotImplementedError(f"Backend {BACKEND} _put not implemented.") diff --git a/ebcc/backend/_ctf.py b/ebcc/backend/_ctf.py index 43eca709..af3a57a1 100644 --- a/ebcc/backend/_ctf.py +++ b/ebcc/backend/_ctf.py @@ -4,7 +4,6 @@ import ctf import numpy import opt_einsum -import warnings def __getattr__(name): @@ -16,17 +15,17 @@ class FakeLinalg: """Fake linalg module for CTF.""" def __getattr__(self, name): + """Get the attribute from CTF's linalg module.""" return getattr(ctf.linalg, name) - def eigh(self, a): + def eigh(self, a): # noqa: D102 # TODO Need to determine if SCALAPACK is available w, v = numpy.linalg.eigh(a.to_nparray()) w = ctf.astensor(w) v = ctf.astensor(v) return w, v - #norm = ctf.norm - def norm(self, a, ord=None): + def norm(self, a, ord=None): # noqa: D102 return ctf.norm(a, ord=ord) @@ -40,48 +39,49 @@ def norm(self, a, ord=None): _array = ctf.array -def array(obj, **kwargs): + +def array(obj, **kwargs): # noqa: D103 if isinstance(obj, ctf.tensor): return obj return _array(numpy.asarray(obj), **kwargs) -def astype(obj, dtype): +def astype(obj, dtype): # noqa: D103 return obj.astype(dtype) -def zeros_like(obj): +def zeros_like(obj): # noqa: D103 return ctf.zeros(obj.shape).astype(obj.dtype) -def ones_like(obj): +def ones_like(obj): # noqa: D103 return ctf.ones(obj.shape).astype(obj.dtype) -def arange(start, stop=None, step=1, dtype=None): +def arange(start, stop=None, step=1, dtype=None): # noqa: D103 if stop is None: stop = start start = 0 return ctf.arange(start, stop, step=step, dtype=dtype) -def argmin(obj): +def argmin(obj): # noqa: D103 return ctf.to_nparray(obj).argmin() -def argmax(obj): +def argmax(obj): # noqa: D103 return ctf.to_nparray(obj).argmax() -def bitwise_and(a, b): +def bitwise_and(a, b): # noqa: D103 return a * b -def bitwise_not(a): +def bitwise_not(a): # noqa: D103 return ones_like(a) - a -def concatenate(arrays, axis=None): +def concatenate(arrays, axis=None): # noqa: D103 if axis is None: axis = 0 if axis < 0: @@ -106,7 +106,7 @@ def concatenate(arrays, axis=None): return result -def _block_recursive(arrays, max_depth, depth=0): +def _block_recursive(arrays, max_depth, depth=0): # noqa: D103 if depth < max_depth: arrs = [_block_recursive(arr, max_depth, depth + 1) for arr in arrays] return concatenate(arrs, axis=-(max_depth - depth)) @@ -114,7 +114,7 @@ def _block_recursive(arrays, max_depth, depth=0): return arrays -def block(arrays): +def block(arrays): # noqa: D103 def _get_max_depth(arrays): if isinstance(arrays, list): return 1 + max([_get_max_depth(arr) for arr in arrays]) diff --git a/ebcc/backend/_jax.py b/ebcc/backend/_jax.py index 90a11c8f..6ccdcf64 100644 --- a/ebcc/backend/_jax.py +++ b/ebcc/backend/_jax.py @@ -15,7 +15,8 @@ def __getattr__(name): _jax_ix_ = jax.numpy.ix_ -def ix_(*args): + +def ix_(*args): # noqa: D103 args_ = [] for arg in args: if isinstance(arg, jax.numpy.ndarray) and arg.dtype == jax.numpy.bool_: diff --git a/ebcc/backend/_numpy.py b/ebcc/backend/_numpy.py index 644a61a8..ddc771ac 100644 --- a/ebcc/backend/_numpy.py +++ b/ebcc/backend/_numpy.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from numpy import generic - from numpy.typing import DType, NDArray + from numpy.typing import NDArray def __getattr__(name: str) -> object: @@ -16,7 +16,7 @@ def __getattr__(name: str) -> object: return getattr(numpy, name) -def astype(obj: NDArray[generic], dtype: DType) -> NDArray[generic]: +def astype(obj: NDArray[generic], dtype: type[generic]) -> NDArray[generic]: """Cast the array to the specified type. Args: diff --git a/ebcc/backend/_tensorflow.py b/ebcc/backend/_tensorflow.py index 9e6cda0b..499c0d5c 100644 --- a/ebcc/backend/_tensorflow.py +++ b/ebcc/backend/_tensorflow.py @@ -15,11 +15,11 @@ def __getattr__(name): return getattr(tensorflow.experimental.numpy, name) -def astype(obj, dtype): +def astype(obj, dtype): # noqa: D103 return obj.astype(dtype) -def _block_recursive(arrays, max_depth, depth=0): +def _block_recursive(arrays, max_depth, depth=0): # noqa: D103 if depth < max_depth: arrs = [_block_recursive(arr, max_depth, depth + 1) for arr in arrays] return tensorflow.experimental.numpy.concatenate(arrs, axis=-(max_depth - depth)) @@ -27,7 +27,7 @@ def _block_recursive(arrays, max_depth, depth=0): return arrays -def block(arrays): +def block(arrays): # noqa: D103 def _get_max_depth(arrays): if isinstance(arrays, list): return 1 + max([_get_max_depth(arr) for arr in arrays]) @@ -36,7 +36,7 @@ def _get_max_depth(arrays): return _block_recursive(arrays, _get_max_depth(arrays)) -def ravel_multi_index(multi_index, dims, mode="raise", order="C"): +def ravel_multi_index(multi_index, dims, mode="raise", order="C"): # noqa: D103 if mode != "raise": raise NotImplementedError("Only 'raise' mode is implemented") if order != "C": @@ -51,7 +51,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): return flat_index -def indices(dimensions, dtype=tf.int32, sparse=False): +def indices(dimensions, dtype=tf.int32, sparse=False): # noqa: D103 # Generate a range of indices for each dimension ranges = [tf.range(dim, dtype=dtype) for dim in dimensions] diff --git a/ebcc/cc/base.py b/ebcc/cc/base.py index f1bedf5c..e5f7ac24 100644 --- a/ebcc/cc/base.py +++ b/ebcc/cc/base.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from ebcc import default_log, init_logging from ebcc import numpy as np @@ -256,7 +256,9 @@ def kernel(self, eris: Optional[ERIsInputType] = None) -> float: vector = self.amplitudes_to_vector(amplitudes) vector = diis.update(vector) amplitudes = self.vector_to_amplitudes(vector) - dt = np.linalg.norm(np.abs(vector - self.amplitudes_to_vector(amplitudes_prev)), ord=np.inf) + dt = np.linalg.norm( + np.abs(vector - self.amplitudes_to_vector(amplitudes_prev)), ord=np.inf + ) # Update the energy and calculate change: e_prev = e_cc @@ -951,7 +953,7 @@ def const(self) -> float: """ if self.options.shift: assert self.omega is not None - return np.ravel(util.einsum("I,I->", self.omega, self.xi**2.0))[0] + return cast(float, np.ravel(util.einsum("I,I->", self.omega, self.xi**2.0))[0]) return 0.0 @property diff --git a/ebcc/cc/gebcc.py b/ebcc/cc/gebcc.py index 756dce0a..0803ccc2 100644 --- a/ebcc/cc/gebcc.py +++ b/ebcc/cc/gebcc.py @@ -192,7 +192,8 @@ def from_uebcc(cls, ucc: UEBCC) -> GEBCC: amplitudes[name] = _put( amplitudes[name], mask, - amplitudes[name][mask] + np.transpose(amp, transpose) * sign, + amplitudes[name][mask] + + np.transpose(amp, transpose) * sign, ) done.add(combn) @@ -234,7 +235,8 @@ def from_uebcc(cls, ucc: UEBCC) -> GEBCC: amplitudes[name] = _put( amplitudes[name], mask, - amplitudes[name][mask] + np.transpose(amp, transpose) * sign, + amplitudes[name][mask] + + np.transpose(amp, transpose) * sign, ) done.add(combn) diff --git a/ebcc/cc/uebcc.py b/ebcc/cc/uebcc.py index ab03a595..32bed76d 100644 --- a/ebcc/cc/uebcc.py +++ b/ebcc/cc/uebcc.py @@ -243,7 +243,9 @@ def init_amps(self, eris: Optional[ERIsInputType] = None) -> Namespace[SpinArray elif n == 2: comb_t = comb[0] + comb[2] + comb[1] + comb[3] key_t = key[0] + key[2] + key[1] + key[3] - tn[comb] = np.transpose(eris[comb_t][key_t], (0, 2, 1, 3)) / self.energy_sum(key, comb) + tn[comb] = np.transpose(eris[comb_t][key_t], (0, 2, 1, 3)) / self.energy_sum( + key, comb + ) if comb in ("aaaa", "bbbb"): # TODO generalise: tn[comb] = (tn[comb] - np.transpose(tn[comb], (1, 0, 2, 3))) * 0.5 diff --git a/ebcc/core/damping.py b/ebcc/core/damping.py index 4a46d852..d117deb4 100644 --- a/ebcc/core/damping.py +++ b/ebcc/core/damping.py @@ -112,10 +112,12 @@ def update(self, x: NDArray[T], xerr: Optional[NDArray[T]] = None) -> NDArray[T] # this looks crazy, but it's just updating the `self._index`th row and # column with the new errors, it's just done this way to avoid using # calls to `__setitem__` in immutable backends - m_i = np.array([ - np.ravel(np.dot(np.conj(np.ravel(x1)), np.ravel(self._errors[i])))[0] - for i in range(nd) - ]) + m_i = np.array( + [ + np.ravel(np.dot(np.conj(np.ravel(x1)), np.ravel(self._errors[i])))[0] + for i in range(nd) + ] + ) m_i = np.concatenate([np.array([1.0]), m_i, np.zeros(self.space - nd)]) m_i = np.reshape(m_i, (-1, 1)) m_j = np.conj(np.transpose(m_i)) @@ -166,7 +168,7 @@ def extrapolate(self, nd: Optional[int] = None) -> NDArray[T]: if np.any(np.abs(w) < 1e-14): mask = np.abs(w) > 1e-14 w, v = w[mask], v[:, mask] - c = util.einsum("pi,qi,i,q->p", v, np.conj(v), w ** -1, g) + c = util.einsum("pi,qi,i,q->p", v, np.conj(v), w**-1.0, g) # Construct the new vector xnew: NDArray[T] = np.zeros_like(self._arrays[0]) diff --git a/ebcc/eom/base.py b/ebcc/eom/base.py index f91df55c..7def94c5 100644 --- a/ebcc/eom/base.py +++ b/ebcc/eom/base.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from typing import Any, Callable, Optional - from numpy import float64, int64 + from numpy import float64 from numpy.typing import NDArray from ebcc.cc.base import BaseEBCC, ERIsInputType, SpaceType, SpinArrayType diff --git a/ebcc/eom/geom.py b/ebcc/eom/geom.py index 0601ac0b..9483c18a 100644 --- a/ebcc/eom/geom.py +++ b/ebcc/eom/geom.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing import Optional - from numpy import float64, int64 + from numpy import float64 from numpy.typing import NDArray from ebcc.cc.gebcc import GEBCC, ERIsInputType, SpinArrayType @@ -32,7 +32,7 @@ class GEOM(BaseEOM): class IP_GEOM(GEOM, BaseIP_EOM): """Generalised ionisation potential equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] @@ -128,7 +128,7 @@ def vector_to_amplitudes(self, vector: NDArray[T]) -> Namespace[SpinArrayType]: class EA_GEOM(GEOM, BaseEA_EOM): """Generalised electron affinity equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] @@ -224,7 +224,7 @@ def vector_to_amplitudes(self, vector: NDArray[T]) -> Namespace[SpinArrayType]: class EE_GEOM(GEOM, BaseEE_EOM): """Generalised electron-electron equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] diff --git a/ebcc/eom/reom.py b/ebcc/eom/reom.py index 3b6699be..53105bb1 100644 --- a/ebcc/eom/reom.py +++ b/ebcc/eom/reom.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing import Optional - from numpy import float64, int64 + from numpy import float64 from numpy.typing import NDArray from ebcc.cc.rebcc import REBCC, ERIsInputType, SpinArrayType @@ -220,7 +220,7 @@ def vector_to_amplitudes(self, vector: NDArray[T]) -> Namespace[SpinArrayType]: class EE_REOM(REOM, BaseEE_EOM): """Restricted electron-electron equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] diff --git a/ebcc/eom/ueom.py b/ebcc/eom/ueom.py index 049b5602..022c05bf 100644 --- a/ebcc/eom/ueom.py +++ b/ebcc/eom/ueom.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing import Optional - from numpy import float64, int64 + from numpy import float64 from numpy.typing import NDArray from ebcc.cc.uebcc import UEBCC, ERIsInputType, SpinArrayType @@ -32,7 +32,7 @@ class UEOM(BaseEOM): class IP_UEOM(UEOM, BaseIP_EOM): """Unrestricted ionisation potential equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] @@ -147,7 +147,7 @@ def vector_to_amplitudes(self, vector: NDArray[T]) -> Namespace[SpinArrayType]: class EA_UEOM(UEOM, BaseEA_EOM): """Unrestricted electron affinity equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] @@ -262,7 +262,7 @@ def vector_to_amplitudes(self, vector: NDArray[T]) -> Namespace[SpinArrayType]: class EE_UEOM(UEOM, BaseEE_EOM): """Unrestricted electron-electron equation-of-motion coupled cluster.""" - def _argsort_guesses(self, diag: NDArray[T]) -> NDArray[int64]: + def _argsort_guesses(self, diag: NDArray[T]) -> list[int]: """Sort the diagonal to inform the initial guesses.""" if self.options.koopmans: r1 = self.vector_to_amplitudes(diag)["r1"] diff --git a/ebcc/ham/cderis.py b/ebcc/ham/cderis.py index 6edbcc91..0f49538b 100644 --- a/ebcc/ham/cderis.py +++ b/ebcc/ham/cderis.py @@ -57,7 +57,9 @@ def __getitem__(self, key: str, e2: Optional[bool] = False) -> NDArray[T]: if key_e2 not in self._members: s = 0 if not e2 else 2 coeffs = [ - numpy.asarray(self.mo_coeff[i + s][:, self.space[i + s].slice(k)], dtype=numpy.float64) + numpy.asarray( + self.mo_coeff[i + s][:, self.space[i + s].slice(k)], dtype=numpy.float64 + ) for i, k in enumerate(key) ] ijslice = ( diff --git a/ebcc/ham/elbos.py b/ebcc/ham/elbos.py index f4d925e9..5a4ec4e4 100644 --- a/ebcc/ham/elbos.py +++ b/ebcc/ham/elbos.py @@ -4,8 +4,8 @@ from typing import TYPE_CHECKING -from ebcc.ham.base import BaseElectronBoson, BaseGHamiltonian, BaseRHamiltonian, BaseUHamiltonian from ebcc import numpy as np +from ebcc.ham.base import BaseElectronBoson, BaseGHamiltonian, BaseRHamiltonian, BaseUHamiltonian if TYPE_CHECKING: from numpy import float64 diff --git a/ebcc/ham/space.py b/ebcc/ham/space.py index f8337303..5f105087 100644 --- a/ebcc/ham/space.py +++ b/ebcc/ham/space.py @@ -23,6 +23,7 @@ T = float64 B = bool_ + _slice = slice # Development note: multiplication of boolean arrays is used in place of logical or bitwise # AND functions. This is because backends are not guaranteed to support logical or bitwise @@ -124,8 +125,8 @@ def mask(self, char: str) -> NDArray[B]: "a": self.inactive_virtual, }[char] - @functools.lru_cache(maxsize=128) - def slice(self, char: str) -> slice: + @functools.lru_cache(maxsize=128) # noqa: B019 + def slice(self, char: str) -> _slice: """Convert a character corresponding to a space to a slice of that space. Args: @@ -169,7 +170,7 @@ def vmask(self, char: str) -> NDArray[B]: """ return self.mask(char)[self.virtual] - def oslice(self, char: str) -> slice: + def oslice(self, char: str) -> _slice: """Like `slice`, but returns only a slice into only the occupied sector. Args: @@ -182,7 +183,7 @@ def oslice(self, char: str) -> slice: nocc = self.nocc return slice(s.start, min(s.stop, nocc)) - def vslice(self, char: str) -> slice: + def vslice(self, char: str) -> _slice: """Like `slice`, but returns only a slice into only the virtual sector. Args: @@ -414,6 +415,7 @@ def construct_fno_space( """ # Get the MP2 1RDM solver = MP2(mf) + dm1: NDArray[T] if not amplitudes: solver.kernel() dm1 = np.astype(solver.make_rdm1(), types[float]) diff --git a/ebcc/opt/base.py b/ebcc/opt/base.py index ad45b8fb..39ce3c57 100644 --- a/ebcc/opt/base.py +++ b/ebcc/opt/base.py @@ -9,8 +9,8 @@ import numpy from pyscf import lib -from ebcc import util from ebcc import numpy as np +from ebcc import util from ebcc.core.damping import DIIS from ebcc.core.logging import ANSI, NullLogger, init_logging from ebcc.core.precision import types diff --git a/ebcc/opt/gbrueckner.py b/ebcc/opt/gbrueckner.py index 3767b2af..54a41c2e 100644 --- a/ebcc/opt/gbrueckner.py +++ b/ebcc/opt/gbrueckner.py @@ -64,8 +64,7 @@ def get_rotation_matrix( if np.linalg.det(u_tot) < 0: u_tot = _put(u_tot, np.ix_(np.arange(u_tot.shape[0]), np.array([0])), -u_tot[:, 0]) - a = scipy.linalg.logm(u_tot) - a = np.astype(np.real(a), types[float]) + a: NDArray[T] = np.astype(np.real(scipy.linalg.logm(u_tot)), types[float]) if diis is not None: a = diis.update(a, xerr=t1) diff --git a/ebcc/opt/rbrueckner.py b/ebcc/opt/rbrueckner.py index cd9b0d4c..2f9d3021 100644 --- a/ebcc/opt/rbrueckner.py +++ b/ebcc/opt/rbrueckner.py @@ -64,8 +64,7 @@ def get_rotation_matrix( if np.linalg.det(u_tot) < 0: u_tot = _put(u_tot, np.ix_(np.arange(u_tot.shape[0]), np.array([0])), -u_tot[:, 0]) - a = scipy.linalg.logm(u_tot) - a = np.astype(np.real(a), types[float]) + a: NDArray[T] = np.astype(np.real(scipy.linalg.logm(u_tot)), types[float]) if diis is not None: a = diis.update(a, xerr=t1) diff --git a/ebcc/opt/ubrueckner.py b/ebcc/opt/ubrueckner.py index 23f8be94..a5b5e25b 100644 --- a/ebcc/opt/ubrueckner.py +++ b/ebcc/opt/ubrueckner.py @@ -81,7 +81,7 @@ def get_rotation_matrix( a = np.concatenate( [np.ravel(scipy.linalg.logm(u_tot.aa)), np.ravel(scipy.linalg.logm(u_tot.bb))], axis=0 ) - a = np.astype(np.real(a), types[float]) + a: NDArray[T] = np.astype(np.real(a), types[float]) if diis is not None: xerr = np.concatenate([np.ravel(t1.aa), np.ravel(t1.bb)]) a = diis.update(a, xerr=xerr) diff --git a/ebcc/util/__init__.py b/ebcc/util/__init__.py index 2e0d7925..799476bc 100644 --- a/ebcc/util/__init__.py +++ b/ebcc/util/__init__.py @@ -1,6 +1,14 @@ """Utilities.""" -from ebcc.util.misc import Inherited, ModelNotImplemented, Namespace, Timer, regularise_tuple, prod, argsort +from ebcc.util.misc import ( + Inherited, + ModelNotImplemented, + Namespace, + Timer, + regularise_tuple, + prod, + argsort, +) from ebcc.util.einsumfunc import einsum, dirsum from ebcc.util.permutations import ( antisymmetrise_array, diff --git a/ebcc/util/misc.py b/ebcc/util/misc.py index 2d62132c..3ac09e3c 100644 --- a/ebcc/util/misc.py +++ b/ebcc/util/misc.py @@ -8,17 +8,22 @@ if TYPE_CHECKING: from abc import abstractmethod - from typing import Any, ItemsView, Iterator, KeysView, Union, ValuesView, Protocol + from typing import Any, ItemsView, Iterator, KeysView, Protocol, Union, ValuesView + + from numpy import generic + from numpy.typing import NDArray class Comparable(Protocol): """Protocol for comparable objects.""" @abstractmethod - def __lt__(self, other: C) -> bool: + def __lt__(self, other: C) -> Any: + """Check if the object is less than another.""" pass C = TypeVar("C", bound=Comparable) + T = TypeVar("T") @@ -174,7 +179,7 @@ def prod(values: Union[list[int], tuple[int, ...]]) -> int: return out -def argsort(values: list[C]) -> list[int]: +def argsort(values: Union[list[Union[float, str]], NDArray[generic]]) -> list[int]: """Return the indices that would sort the values. Args: diff --git a/ebcc/util/permutations.py b/ebcc/util/permutations.py index 1361e77b..e706778d 100644 --- a/ebcc/util/permutations.py +++ b/ebcc/util/permutations.py @@ -324,7 +324,9 @@ def compress_axes( assert sizes[char] == n else: sizes[char] = n - array = np.reshape(array, [sizes[char] ** subscript.count(char) for char in sorted(set(subscript))]) + array = np.reshape( + array, [sizes[char] ** subscript.count(char) for char in sorted(set(subscript))] + ) # For each axis type, get the necessary lower-triangular indices: indices_ndim = [ @@ -409,7 +411,9 @@ def decompress_axes( assert sizes[char] == n else: sizes[char] = n - array = np.reshape(array, [sizes[char] ** subscript.count(char) for char in sorted(set(subscript))]) + array = np.reshape( + array, [sizes[char] ** subscript.count(char) for char in sorted(set(subscript))] + ) # Check the symmetry string, and compress it: n = 0 @@ -442,12 +446,13 @@ def decompress_axes( array = _put(array, indices_perm, array_flat * util.prod(signs)) # Reshape array to non-flattened format - array = np.reshape(array, ( - sum([(sizes[char],) * subscript.count(char) for char in sorted(set(subscript))], tuple()) - )) + array = np.reshape( + array, + (sum([(sizes[char],) * subscript.count(char) for char in sorted(set(subscript))], tuple())), + ) # Undo transpose: - arg = tuple(util.argsort(arg)) + arg = tuple(util.argsort(list(arg))) array = np.transpose(array, arg) return array diff --git a/pyproject.toml b/pyproject.toml index 18f14e18..086e2493 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,10 @@ ignore_missing_imports = true module = "tensorflow.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "ctf.*" +ignore_missing_imports = true + [tool.coverage.run] branch = true source = [