diff --git a/src/click/__init__.py b/src/click/__init__.py index 1aa547c57a..73f26798ca 100644 --- a/src/click/__init__.py +++ b/src/click/__init__.py @@ -52,6 +52,7 @@ from .types import BOOL as BOOL from .types import Choice as Choice from .types import DateTime as DateTime +from .types import EnumChoice as EnumChoice from .types import File as File from .types import FLOAT as FLOAT from .types import FloatRange as FloatRange diff --git a/src/click/types.py b/src/click/types.py index 8d9750b63c..3273ada659 100644 --- a/src/click/types.py +++ b/src/click/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections.abc as cabc +import enum import os import stat import sys @@ -336,6 +337,23 @@ def shell_complete( return [CompletionItem(c) for c in matched] +class EnumChoice(Choice): + def __init__(self, enum_type: type[enum.Enum], case_sensitive: bool = True): + super().__init__( + choices=[element.name for element in enum_type], + case_sensitive=case_sensitive, + ) + self.enum_type = enum_type + + def convert( + self, value: t.Any, param: Parameter | None, ctx: Context | None + ) -> t.Any: + value = super().convert(value=value, param=param, ctx=ctx) + if value is None: + return None + return self.enum_type[value] + + class DateTime(ParamType): """The DateTime type converts date strings into `datetime` objects. diff --git a/tests/test_basic.py b/tests/test_basic.py index d68b96299b..517fcc3245 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,4 @@ +import enum import os from itertools import chain @@ -6,6 +7,15 @@ import click +class MyEnum(enum.Enum): + """Dummy enum for unit tests.""" + + ONE = "one" + TWO = "two" + THREE = "three" + ONE_ALIAS = ONE + + def test_basic_functionality(runner): @click.command() def cli(): @@ -403,6 +413,48 @@ def cli(method): assert "{foo|bar|baz}" in result.output +def test_enum_choice_option(runner): + @click.command() + @click.option("--number", type=click.EnumChoice(MyEnum)) + def cli(number): + click.echo(number) + + result = runner.invoke(cli, ["--number=ONE"]) + assert not result.exception + assert result.output == "MyEnum.ONE\n" + + result = runner.invoke(cli, ["--number=meh"]) + assert result.exit_code == 2 + assert ( + "Invalid value for '--number': 'meh' is not one of 'ONE', 'TWO', 'THREE'." + in result.output + ) + + result = runner.invoke(cli, ["--help"]) + assert "--number [ONE|TWO|THREE]" in result.output + + +def test_enum_choice_argument(runner): + @click.command() + @click.argument("number", type=click.EnumChoice(MyEnum)) + def cli(number): + click.echo(number) + + result = runner.invoke(cli, ["ONE"]) + assert not result.exception + assert result.output == "MyEnum.ONE\n" + + result = runner.invoke(cli, ["meh"]) + assert result.exit_code == 2 + assert ( + "Invalid value for '{ONE|TWO|THREE}': 'meh' is not one of 'ONE', " + "'TWO', 'THREE'." in result.output + ) + + result = runner.invoke(cli, ["--help"]) + assert "{ONE|TWO|THREE}" in result.output + + def test_datetime_option_default(runner): @click.command() @click.option("--start_date", type=click.DateTime()) diff --git a/tests/test_info_dict.py b/tests/test_info_dict.py index 79d39ee513..324ee12b30 100644 --- a/tests/test_info_dict.py +++ b/tests/test_info_dict.py @@ -1,8 +1,11 @@ +import enum + import pytest import click.types # Common (obj, expect) pairs used to construct multiple tests. + STRING_PARAM_TYPE = (click.STRING, {"param_type": "String", "name": "text"}) INT_PARAM_TYPE = (click.INT, {"param_type": "Int", "name": "integer"}) BOOL_PARAM_TYPE = (click.BOOL, {"param_type": "Bool", "name": "boolean"}) @@ -91,6 +94,15 @@ ) +class MyEnum(enum.Enum): + """Dummy enum for unit tests.""" + + ONE = "one" + TWO = "two" + THREE = "three" + ONE_ALIAS = ONE + + @pytest.mark.parametrize( ("obj", "expect"), [ @@ -115,6 +127,16 @@ }, id="Choice ParamType", ), + pytest.param( + click.EnumChoice(MyEnum), + { + "param_type": "EnumChoice", + "name": "choice", + "choices": ["ONE", "TWO", "THREE"], + "case_sensitive": True, + }, + id="EnumChoice ParamType", + ), pytest.param( click.DateTime(["%Y-%m-%d"]), {"param_type": "DateTime", "name": "datetime", "formats": ["%Y-%m-%d"]}, diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 502e654a37..fdc6a825aa 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1,8 +1,19 @@ +import enum + import click CONTEXT_SETTINGS = dict(token_normalize_func=lambda x: x.lower()) +class MyEnum(enum.Enum): + """Dummy enum for unit tests.""" + + ONE = "one" + TWO = "two" + THREE = "three" + ONE_ALIAS = ONE + + def test_option_normalization(runner): @click.command(context_settings=CONTEXT_SETTINGS) @click.option("--foo") @@ -25,6 +36,16 @@ def cli(choice): assert result.output == "Foo\n" +def test_enum_choice_normalization(runner): + @click.command(context_settings=CONTEXT_SETTINGS) + @click.option("--choice", type=click.EnumChoice(MyEnum)) + def cli(choice): + click.echo(choice) + + result = runner.invoke(cli, ["--CHOICE", "ONE"]) + assert result.output == "MyEnum.ONE\n" + + def test_command_normalization(runner): @click.group(context_settings=CONTEXT_SETTINGS) def cli(): diff --git a/tests/test_options.py b/tests/test_options.py index 7397f36676..f01922e385 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -1,3 +1,4 @@ +import enum import os import re @@ -7,6 +8,15 @@ from click import Option +class MyEnum(enum.Enum): + """Dummy enum for unit tests.""" + + ONE = "one" + TWO = "two" + THREE = "three" + ONE_ALIAS = ONE + + def test_prefixes(runner): @click.command() @click.option("++foo", is_flag=True, help="das foo") @@ -571,6 +581,67 @@ def cmd(foo): assert result.output == "Apple\n" +def test_missing_enum_choice(runner): + @click.command() + @click.option("--foo", type=click.EnumChoice(MyEnum), required=True) + def cmd(foo): + click.echo(foo) + + result = runner.invoke(cmd) + assert result.exit_code == 2 + error, separator, choices = result.output.partition("Choose from") + assert "Error: Missing option '--foo'. " in error + assert "Choose from" in separator + assert "ONE" in choices + assert "TWO" in choices + assert "THREE" in choices + assert "ONE_ALIAS" not in choices + + +def test_case_insensitive_enum_choice(runner): + @click.command() + @click.option("--foo", type=click.EnumChoice(MyEnum, case_sensitive=False)) + def cmd(foo): + click.echo(foo) + + result = runner.invoke(cmd, ["--foo", "one"]) + assert result.exit_code == 0 + assert result.output == "MyEnum.ONE\n" + + result = runner.invoke(cmd, ["--foo", "tHREE"]) + assert result.exit_code == 0 + assert result.output == "MyEnum.THREE\n" + + result = runner.invoke(cmd, ["--foo", "Two"]) + assert result.exit_code == 0 + assert result.output == "MyEnum.TWO\n" + + @click.command() + @click.option("--foo", type=click.EnumChoice(MyEnum)) + def cmd2(foo): + click.echo(foo) + + result = runner.invoke(cmd2, ["--foo", "one"]) + assert result.exit_code == 2 + + result = runner.invoke(cmd2, ["--foo", "tHREE"]) + assert result.exit_code == 2 + + result = runner.invoke(cmd2, ["--foo", "TWO"]) + assert result.exit_code == 0 + + +def test_case_insensitive_enum_choice_returned_exactly(runner): + @click.command() + @click.option("--foo", type=click.EnumChoice(MyEnum, case_sensitive=False)) + def cmd(foo): + click.echo(foo) + + result = runner.invoke(cmd, ["--foo", "ONE"]) + assert result.exit_code == 0 + assert result.output == "MyEnum.ONE\n" + + def test_option_help_preserve_paragraphs(runner): @click.command() @click.option(