From 1e633f06bf12b5222595398a9257abcaed52b1b1 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sun, 24 Mar 2024 13:16:49 +0100 Subject: [PATCH] fix warnings, use ruff to isort --- .github/workflows/ci.yml | 4 +--- .pre-commit-config.yaml | 9 +++------ pivotal/simplex.py | 23 +++++++++++++++++------ pyproject.toml | 12 +++++++----- tests/__init__.py | 0 tests/test_api.py | 4 ++-- tests/test_simplex.py | 4 ++-- 7 files changed, 32 insertions(+), 24 deletions(-) create mode 100644 tests/__init__.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ab4b05f..697e45e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,9 +25,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: python -m pip install -e .[dev] - - name: Run ruff - run: ruff check --verbose - - name: Run linters + - name: Run linters & formatters run: pre-commit run --all-files - name: Test with pytest run: pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4905f5a..ea7ba06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,14 +5,11 @@ repos: hooks: # Run the linter. - id: ruff - args: [--fix, --verbose] + args: [--fix] + # Run the formatter. + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - name: isort (python) diff --git a/pivotal/simplex.py b/pivotal/simplex.py index 56ce13b..1e3744e 100644 --- a/pivotal/simplex.py +++ b/pivotal/simplex.py @@ -1,13 +1,21 @@ import math import warnings +from collections.abc import Callable from enum import Enum, auto -from typing import Literal +from typing import Literal, TypeVar import numpy as np from pivotal.errors import Infeasible, Unbounded -from pivotal.expressions import (Constraint, Equal, Expression, GreaterOrEqual, LessOrEqual, get_variable_coeffs, - get_variable_names) +from pivotal.expressions import ( + Constraint, + Equal, + Expression, + GreaterOrEqual, + LessOrEqual, + get_variable_coeffs, + get_variable_names, +) class ProgramType(Enum): @@ -15,8 +23,11 @@ class ProgramType(Enum): MAX = auto() -def suppress_divide_by_zero_warning(fn) -> None: - def _fn_suppressed(*args, **kwargs): +_T = TypeVar("_T") + + +def suppress_divide_by_zero_warning(fn: Callable[..., _T]) -> None: + def _fn_suppressed(*args, **kwargs) -> _T: with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) return fn(*args, **kwargs) @@ -71,7 +82,7 @@ def __repr__(self) -> str: class Tableau: - def __init__( + def __init__( # noqa: PLR0913 self, A: np.ndarray, b: np.ndarray, c: np.ndarray, pivots: Pivots | None = None, *, tolerance: float = 1e-6 ) -> None: self.M = np.block([[np.atleast_2d(c), np.atleast_2d(0)], [A, np.atleast_2d(b).T]]) diff --git a/pyproject.toml b/pyproject.toml index d1cd329..34e66ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,16 +31,12 @@ dependencies = ["numpy"] dynamic = ["version"] [project.optional-dependencies] -dev = ["pytest", "ruff", "isort", "pre-commit"] +dev = ["pytest", "ruff", "pre-commit"] [project.urls] Homepage = "https://github.com/tomasr8/pivotal" Github = "https://github.com/tomasr8/pivotal" -[tool.isort] -line_length = 120 -lines_after_imports = 2 - [tool.ruff.lint] select = ["ALL"] ignore = [ @@ -58,9 +54,15 @@ ignore = [ "S101", ] +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["ANN001", "ANN201", "PLR2004"] + [tool.ruff] line-length = 120 +[tool.ruff.lint.isort] +lines-after-imports = 2 + [tool.setuptools] packages = ["pivotal"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api.py b/tests/test_api.py index 8181083..77eb50a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,7 +8,7 @@ def assert_solution_almost_equal(expected, actual): __traceback_hide__ = True # noqa: F841 assert np.isclose(expected[0], actual[0]) - assert_allclose(expected[1], [actual[1][name] for name in sorted(list(actual[1].keys()))], atol=1e-8) + assert_allclose(expected[1], [actual[1][name] for name in sorted(actual[1].keys())], atol=1e-8) def assert_equal(a, b): @@ -19,7 +19,7 @@ def assert_equal(a, b): def assert_array_equal(a, b): __traceback_hide__ = True # noqa: F841 assert len(a) == len(b) - for x, y in zip(a, b): + for x, y in zip(a, b, strict=True): assert_equal(x, y) diff --git a/tests/test_simplex.py b/tests/test_simplex.py index d5a39d8..206e3b9 100644 --- a/tests/test_simplex.py +++ b/tests/test_simplex.py @@ -189,7 +189,7 @@ def test_program_1(): value, X = _solve(Tableau(A, b, c, Pivots({})), max_iterations=math.inf, tolerance=1e-6) X_true = [0.81818182, 1.72727273] assert value == pytest.approx(4.27272727) - for x, xt in zip(X, X_true): + for x, xt in zip(X, X_true, strict=True): assert x == pytest.approx(xt) @@ -208,5 +208,5 @@ def test_program_2(): value, X = _solve(Tableau(A, b, c, Pivots({})), max_iterations=math.inf, tolerance=1e-6) X_true = [1, 0, 7] assert value == pytest.approx(-300) - for x, xt in zip(X, X_true): + for x, xt in zip(X, X_true, strict=True): assert x == pytest.approx(xt)