From 60bea668f7bf4359a447487555b9209ae5b1e07b Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 12 Jun 2023 23:10:59 -0700 Subject: [PATCH] Add TorchFix linter (#2179) * Add TorchFix linter * Move comments to separate lines * Change assert_allclose to assert_close --- .flake8 | 8 ++++++-- .pre-commit-config.yaml | 1 + test/torchtext_unittest/data/test_jit.py | 6 +++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.flake8 b/.flake8 index e09b485892..0d4b9398ff 100644 --- a/.flake8 +++ b/.flake8 @@ -1,9 +1,13 @@ [flake8] ignore = E401,E402,E501,E722,W503,W504,F821,B006,B007,B008,B009, - E203 # https://github.com/PyCQA/pycodestyle/issues/373 + # https://github.com/PyCQA/pycodestyle/issues/373 + E203 select = B,C,E,F,P,T4,W,B9, - D417 # Missing argument descriptions in the docstring + # Missing argument descriptions in the docstring + D417, + # TorchFix + TOR max-line-length = 120 exclude = docs/source,third_party diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9e64dd07d..fb071843f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,5 +34,6 @@ repos: - id: flake8 additional_dependencies: - flake8-docstrings == 1.6.0 + - torchfix == 0.0.1 args: - --config=.flake8 diff --git a/test/torchtext_unittest/data/test_jit.py b/test/torchtext_unittest/data/test_jit.py index 9dcb90f658..c6f33d385a 100644 --- a/test/torchtext_unittest/data/test_jit.py +++ b/test/torchtext_unittest/data/test_jit.py @@ -1,5 +1,5 @@ import torch -from torch.testing import assert_allclose +from torch.testing import assert_close from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct from ..common.torchtext_test_case import TorchtextTestCase @@ -26,5 +26,5 @@ def test_torchscript_multiheadattention(self) -> None: ts_MHA = torch.jit.script(MHA) ts_mha_output, ts_attn_weights = ts_MHA(query, key, value, attn_mask=attn_mask) - assert_allclose(mha_output, ts_mha_output) - assert_allclose(attn_weights, ts_attn_weights) + assert_close(mha_output, ts_mha_output) + assert_close(attn_weights, ts_attn_weights)