diff --git a/CHANGES.md b/CHANGES.md index edca0dcdad4..7e8826e0c72 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -14,6 +14,8 @@ +- Use parentheses with equality check in walrus/assigment statements (#2770) + ### _Blackd_ diff --git a/src/black/__init__.py b/src/black/__init__.py index c4ec99b441f..2afaee13bdf 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1095,7 +1095,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon Operate cell-by-cell, only on code cells, only for Python notebooks. If the ``.ipynb`` originally had a trailing newline, it'll be preserved. """ - trailing_newline = src_contents[-1] == "\n" + trailing_newline = (src_contents[-1] == "\n") modified = False nb = json.loads(src_contents) validate_metadata(nb) diff --git a/src/black/linegen.py b/src/black/linegen.py index 4dc242a1dfe..2e4d7c68815 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -98,6 +98,29 @@ def visit_default(self, node: LN) -> Iterator[Line]: self.current_line.append(node) yield from super().visit_default(node) + def visit_comparison(self, node: Node) -> Iterator[Line]: + parent: Optional[Node] = node.parent + grandparent: Optional[Node] = parent.parent if parent else None + if ( + # It is a preview style feature only + Preview.parentheses_in_equality_check_assignments in self.mode + and parent is not None + and Leaf(token.EQEQUAL, "==") in node.children + and ( + parent.type == syms.namedexpr_test + or ( + grandparent is not None + and grandparent.type in (syms.expr_stmt, syms.annassign) + ) + ) + ): + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + node.insert_child(0, lpar) + node.insert_child(len(node.children), rpar) + + yield from self.visit_default(node) + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. diff --git a/src/black/mode.py b/src/black/mode.py index 455ed36e27e..7d5dfb51103 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -128,6 +128,7 @@ class Preview(Enum): string_processing = auto() hug_simple_powers = auto() + parentheses_in_equality_check_assignments = auto() class Deprecated(UserWarning): diff --git a/src/black/parsing.py b/src/black/parsing.py index db48ae4baf5..d0f6d0cc014 100644 --- a/src/black/parsing.py +++ b/src/black/parsing.py @@ -24,7 +24,7 @@ ast3: Any -_IS_PYPY = platform.python_implementation() == "PyPy" +_IS_PYPY = (platform.python_implementation() == "PyPy") try: from typed_ast import ast3 diff --git a/src/black/trans.py b/src/black/trans.py index 74d052fe2dc..9d6115463cf 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -1479,7 +1479,7 @@ def passes_all_checks(i: Index) -> bool: True iff ALL of the conditions listed in the 'Transformations' section of this classes' docstring would be be met by returning @i. """ - is_space = string[i] == " " + is_space = (string[i] == " ") is_not_escaped = True j = i - 1 diff --git a/src/black_primer/lib.py b/src/black_primer/lib.py index 13724f431ce..99181c75295 100644 --- a/src/black_primer/lib.py +++ b/src/black_primer/lib.py @@ -29,7 +29,7 @@ TEN_MINUTES_SECONDS = 600 -WINDOWS = system() == "Windows" +WINDOWS = (system() == "Windows") BLACK_BINARY = "black.exe" if WINDOWS else "black" GIT_BINARY = "git.exe" if WINDOWS else "git" LOG = logging.getLogger(__name__) @@ -158,7 +158,7 @@ async def black_run( ) return - stdin_test = project_name.upper() == "STDIN" + stdin_test = (project_name.upper() == "STDIN") cmd = [str(which(BLACK_BINARY))] if "cli_arguments" in project_config and project_config["cli_arguments"]: cmd.extend(_flatten_cli_args(project_config["cli_arguments"])) @@ -348,7 +348,7 @@ async def project_runner( continue repo_path: Optional[Path] = Path(__file__) - stdin_project = project_name.upper() == "STDIN" + stdin_project = (project_name.upper() == "STDIN") if not stdin_project: repo_path = await git_checkout_or_rebase(work_path, project_config, rebase) if not repo_path: diff --git a/src/blackd/middlewares.py b/src/blackd/middlewares.py index 97994ecc1df..c995f225b3d 100644 --- a/src/blackd/middlewares.py +++ b/src/blackd/middlewares.py @@ -10,7 +10,7 @@ def cors(allow_headers: Iterable[str]) -> Middleware: @middleware async def impl(request: Request, handler: Handler) -> StreamResponse: - is_options = request.method == "OPTIONS" + is_options = (request.method == "OPTIONS") is_preflight = is_options and "Access-Control-Request-Method" in request.headers if is_preflight: resp = StreamResponse() diff --git a/tests/data/paren_eq_check_in_assigments.py b/tests/data/paren_eq_check_in_assigments.py new file mode 100644 index 00000000000..d0500f9dc16 --- /dev/null +++ b/tests/data/paren_eq_check_in_assigments.py @@ -0,0 +1,63 @@ +match_count += new_value == old_value + +on_windows: bool = (os.name == "nt") + +implementation_version = ( + platform.python_version() if platform.python_implementation() == "CPython" else "Unknown" +) + +is_mac = platform.system() == 'Darwin' + +s = y == 2 + y == 4 + +name1 = name2 = name3 + +name1 == name2 == name3 + +check_sockets(on_windows=os.name == "nt") + +a = b in c and b == d + +a = b == c == d + +a = b == c in d + +a = b >= c == True + +a = b in c + +a = b > c + +# output + +match_count += (new_value == old_value) + +on_windows: bool = (os.name == "nt") + +implementation_version = ( + platform.python_version() + if platform.python_implementation() == "CPython" + else "Unknown" +) + +is_mac = (platform.system() == "Darwin") + +s = (y == 2 + y == 4) + +name1 = name2 = name3 + +name1 == name2 == name3 + +check_sockets(on_windows=os.name == "nt") + +a = b in c and b == d + +a = (b == c == d) + +a = (b == c in d) + +a = (b >= c == True) + +a = b in c + +a = b > c diff --git a/tests/test_black.py b/tests/test_black.py index b1bf1772550..50f1e112e75 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -162,6 +162,7 @@ def test_piping(self) -> None: black.main, [ "-", + "--preview", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}", f"--config={EMPTY_CONFIG}", diff --git a/tests/test_format.py b/tests/test_format.py index 04eda43d5cf..13682088ca9 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -7,6 +7,7 @@ import black from tests.util import ( DEFAULT_MODE, + PREVIEW_MODE, PY36_VERSIONS, THIS_DIR, assert_format, @@ -79,6 +80,7 @@ "long_strings__edge_case", "long_strings__regression", "percent_precedence", + "paren_eq_check_in_assigments", ] SOURCES: List[str] = [ @@ -144,13 +146,13 @@ def test_simple_format(filename: str) -> None: @pytest.mark.parametrize("filename", PREVIEW_CASES) def test_preview_format(filename: str) -> None: - check_file(filename, black.Mode(preview=True)) + check_file(filename, PREVIEW_MODE) @pytest.mark.parametrize("filename", SOURCES) def test_source_is_formatted(filename: str) -> None: path = THIS_DIR.parent / filename - check_file(str(path), DEFAULT_MODE, data=False) + check_file(str(path), PREVIEW_MODE, data=False) # =============== # diff --git a/tests/util.py b/tests/util.py index 8755111f7c5..0bb8b2a9c7d 100644 --- a/tests/util.py +++ b/tests/util.py @@ -25,6 +25,7 @@ } DEFAULT_MODE = black.Mode() +PREVIEW_MODE = black.Mode(preview=True) ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True) fs = partial(black.format_str, mode=DEFAULT_MODE)