diff --git a/changelog.md b/changelog.md index faaa32a4..410709df 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,20 @@ +Upcoming (TBD) +============== + +Features +-------- +* Update password handling functionality (#341): + 1. Add ability to use the -p/--password/--pass options to launch a password prompt (or to enter cleartext as before) + 2. Check for the -p/--password/--pass options at the earliest possible point to reduce prompt load time (still limited by startup time) + 3. Allow for an empty string password from the envvar MYSQL_PWD + 4. Clarify password option hierachy: + 1. -p / --pass/--password CLI options + 2. envvar (MYSQL_PWD) + 3. DSN (mysql://user:password) + 4. cnf (.my.cnf / etc) + 5. --password-file CLI option + + 1.44.2 (2026/01/13) ============== diff --git a/mycli/main.py b/mycli/main.py index 7258712c..6b4dd865 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -64,7 +64,7 @@ from mycli.packages.tabular_output import sql_format from mycli.packages.toolkit.history import FileHistoryWithTimestamp from mycli.sqlcompleter import SQLCompleter -from mycli.sqlexecute import ERROR_CODE_ACCESS_DENIED, FIELD_TYPES, SQLExecute +from mycli.sqlexecute import FIELD_TYPES, SQLExecute try: import paramiko @@ -460,7 +460,7 @@ def connect( self, database: str | None = "", user: str | None = "", - passwd: str | None = "", + passwd: str | None = None, host: str | None = "", port: str | int | None = "", socket: str | None = "", @@ -528,10 +528,19 @@ def connect( # if the passwd is not specified try to set it using the password_file option password_from_file = self.get_password_from_file(password_file) passwd = passwd if isinstance(passwd, str) else password_from_file - passwd = '' if passwd is None else passwd - # Connect to the database. + # password hierarchy + # 1. -p / --pass/--password CLI options + # 2. envvar (MYSQL_PWD) + # 3. DSN (mysql://user:password) + # 4. cnf (.my.cnf / etc) + # 5. --password-file CLI option + + # if no password was found from all of the above sources, ask for a password + if passwd is None: + passwd = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) + # Connect to the database. def _connect() -> None: try: self.sqlexecute = SQLExecute( @@ -552,31 +561,7 @@ def _connect() -> None: init_command, ) except pymysql.OperationalError as e1: - if e1.args[0] == ERROR_CODE_ACCESS_DENIED: - if password_from_file is not None: - new_passwd = password_from_file - else: - new_passwd = click.prompt( - f"Password for {user}", hide_input=True, show_default=False, default='', type=str, err=True - ) - self.sqlexecute = SQLExecute( - database, - user, - new_passwd, - host, - int_port, - socket, - charset, - use_local_infile, - ssl_config_or_none, - ssh_user, - ssh_host, - int(ssh_port) if ssh_port else None, - ssh_password, - ssh_key_filename, - init_command, - ) - elif e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": + if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": try: self.sqlexecute = SQLExecute( database, @@ -595,33 +580,8 @@ def _connect() -> None: ssh_key_filename, init_command, ) - except pymysql.OperationalError as e2: - if e2.args[0] == ERROR_CODE_ACCESS_DENIED: - if password_from_file is not None: - new_passwd = password_from_file - else: - new_passwd = click.prompt( - f"Password for {user}", hide_input=True, show_default=False, default='', type=str, err=True - ) - self.sqlexecute = SQLExecute( - database, - user, - new_passwd, - host, - int_port, - socket, - charset, - use_local_infile, - None, - ssh_user, - ssh_host, - int(ssh_port) if ssh_port else None, - ssh_password, - ssh_key_filename, - init_command, - ) - else: - raise e2 + except Exception as e2: + raise e2 else: raise e1 @@ -1482,13 +1442,36 @@ def get_last_query(self) -> str | None: return self.query_history[-1][0] if self.query_history else None -@click.command() +# custom parsing class for click options to let us know what order +# the user provides the parameters in on the CLI +class ProvidedParamsOrder(click.Command): + def parse_args(self, ctx, args): + # run parser to get the options and their order + parser = self.make_parser(ctx) + opts, _, param_order = parser.parse_args(args=list(args)) + + # store the ordered parameters in the context object for later use + ctx.obj = {'param_order': param_order} + + # proceed with parsing as normal + return super().parse_args(ctx, args) + + +@click.command(cls=ProvidedParamsOrder) @click.option("-h", "--host", envvar="MYSQL_HOST", help="Host address of the database.") @click.option("-P", "--port", envvar="MYSQL_TCP_PORT", type=int, help="Port number to use for connection. Honors $MYSQL_TCP_PORT.") @click.option("-u", "--user", help="User name to connect to the database.") @click.option("-S", "--socket", envvar="MYSQL_UNIX_PORT", help="The socket file to use for connection.") -@click.option("-p", "--password", "password", envvar="MYSQL_PWD", type=str, help="Password to connect to the database.") -@click.option("--pass", "password", envvar="MYSQL_PWD", type=str, help="Password to connect to the database.") +@click.option( + "-p", + "--pass", + "--password", + "password", + is_flag=False, + flag_value="MYCLI_ASK_PASSWORD", + type=str, + help="Prompt for (or enter in cleartext) password to connect to the database.", +) @click.option("--ssh-user", help="User name to connect to ssh server.") @click.option("--ssh-host", help="Host name to connect to ssh server.") @click.option("--ssh-port", default=22, help="Port to connect to ssh server.") @@ -1548,9 +1531,11 @@ def get_last_query(self) -> str | None: @click.option( "--password-file", type=click.Path(), help="File or FIFO path containing the password to connect to the db if not specified otherwise." ) -@click.argument("database", default="", nargs=1) +@click.argument("database", default=None, nargs=1) +@click.pass_context def cli( - database: str, + ctx: click.Context, + database: str | None, user: str | None, host: str | None, port: int | None, @@ -1603,6 +1588,25 @@ def cli( - mycli mysql://my_user@my_host.com:3306/my_database """ + # get an ordered list of params provided, excluding the + # database argument as that will be provided by default + param_order = [param.name for param in ctx.obj['param_order'] if param.name != "database"] + + # if password is not the flag value, is the last param, and + # database has no value, then assume the password value is + # actually the database and prompt for password + if password != "MYCLI_ASK_PASSWORD" and len(param_order) >= 1 and not database and param_order[-1] == "password": + database = password + password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) + # if user passes the --p* flag, ask for the password right away + # to reduce lag as much as possible + elif password == "MYCLI_ASK_PASSWORD": + password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) + elif password is None and os.environ.get("MYSQL_PWD") is not None: + # getting the envvar ourselves because the envvar from a click + # option cannot be an empty string, but a password can be + password = os.environ.get("MYSQL_PWD") + mycli = MyCli( prompt=prompt, logfile=logfile, diff --git a/test/test_main.py b/test/test_main.py index 076bd986..8cda15ba 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -47,7 +47,7 @@ def test_ssl_mode_on(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert ssl_cipher @@ -58,7 +58,7 @@ def test_ssl_mode_auto(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert ssl_cipher @@ -69,7 +69,7 @@ def test_ssl_mode_off(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert not ssl_cipher @@ -80,7 +80,7 @@ def test_ssl_mode_overrides_ssl(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--ssl"], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert not ssl_cipher @@ -91,7 +91,7 @@ def test_ssl_mode_overrides_no_ssl(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--no-ssl"], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert ssl_cipher