diff --git a/src/Parsers/FunctionSecretArgumentsFinder.h b/src/Parsers/FunctionSecretArgumentsFinder.h index 9ab6689637d0..27b3ed00635e 100644 --- a/src/Parsers/FunctionSecretArgumentsFinder.h +++ b/src/Parsers/FunctionSecretArgumentsFinder.h @@ -241,7 +241,7 @@ class FunctionSecretArgumentsFinder break; if (f->name() == "headers") result.nested_maps.push_back(f->name()); - else if (f->name() != "extra_credentials") + else if (f->name() != "extra_credentials" && f->name() != "equals") break; count -= 1; } @@ -266,6 +266,8 @@ class FunctionSecretArgumentsFinder return; } + findSecretNamedArgument("secret_access_key", url_arg_idx); + /// We should check other arguments first because we don't need to do any replacement in case of /// s3('url', NOSIGN, 'format' [, 'compression'] [, extra_credentials(..)] [, headers(..)]) /// s3('url', 'format', 'structure' [, 'compression'] [, extra_credentials(..)] [, headers(..)]) @@ -625,6 +627,8 @@ class FunctionSecretArgumentsFinder return; } + findSecretNamedArgument("secret_access_key", 0); + /// We should check other arguments first because we don't need to do any replacement in case of /// S3('url', NOSIGN, 'format' [, 'compression'] [, extra_credentials(..)] [, headers(..)]) /// S3('url', 'format', 'compression' [, extra_credentials(..)] [, headers(..)]) diff --git a/src/Storages/ObjectStorage/S3/Configuration.cpp b/src/Storages/ObjectStorage/S3/Configuration.cpp index 92f4ffeba7b2..a89b62dd3952 100644 --- a/src/Storages/ObjectStorage/S3/Configuration.cpp +++ b/src/Storages/ObjectStorage/S3/Configuration.cpp @@ -294,18 +294,103 @@ bool StorageS3Configuration::collectCredentials(ASTPtr maybe_credentials, S3::S3 return true; } +template +static std::optional getFromPositionOrKeyValue( + const std::string & key, + const ASTs & args, + const std::unordered_map & engine_args_to_idx, + const std::unordered_map & key_value_args) +{ + if (auto arg_it = engine_args_to_idx.find(key); arg_it != engine_args_to_idx.end()) + return checkAndGetLiteralArgument(args[arg_it->second], key); + + if (auto arg_it = key_value_args.find(key); arg_it != key_value_args.end()) + return arg_it->second.safeGet(); + + return std::nullopt; +}; + +static std::unordered_map parseKeyValueArguments(const ASTs & function_args, ContextPtr context) +{ + std::unordered_map key_value_args; + for (const auto & arg : function_args) + { + const auto * function_ast = arg->as(); + if (!function_ast || function_ast->name != "equals") + continue; + + auto * args_expr = assert_cast(function_ast->arguments.get()); + auto & children = args_expr->children; + if (children.size() != 2) + { + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Key value argument is incorrect: expected 2 arguments, got {}", + children.size()); + } + + auto key_literal = evaluateConstantExpressionOrIdentifierAsLiteral(children[0], context); + auto value_literal = evaluateConstantExpressionOrIdentifierAsLiteral(children[1], context); + + auto arg_name_value = key_literal->as()->value; + if (arg_name_value.getType() != Field::Types::Which::String) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected string as credential name"); + + auto arg_name = arg_name_value.safeGet(); + auto arg_value = value_literal->as()->value; + + auto inserted = key_value_args.emplace(arg_name, arg_value).second; + if (!inserted) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Duplicate key value argument: {}", arg_name); + } + return key_value_args; +} + +static ASTs::iterator getFirstKeyValueArgument(ASTs & args) +{ + ASTs::iterator first_key_value_arg_it = args.end(); + for (auto * it = args.begin(); it != args.end(); ++it) + { + const auto * function_ast = (*it)->as(); + if (function_ast && function_ast->name == "equals") + { + if (first_key_value_arg_it == args.end()) + first_key_value_arg_it = it; + } + else if (first_key_value_arg_it != args.end()) + { + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Expected positional arguments to go before key-value arguments"); + } + } + return first_key_value_arg_it; +} + void StorageS3Configuration::fromAST(ASTs & args, ContextPtr context, bool with_structure) { auto extra_credentials = extractExtraCredentials(args); size_t count = StorageURL::evalArgsAndCollectHeaders(args, headers_from_ast, context); + ASTs key_value_asts; + if (auto * first_key_value_arg_it = getFirstKeyValueArgument(args); + first_key_value_arg_it != args.end()) + { + key_value_asts = ASTs(first_key_value_arg_it, args.end()); + count -= key_value_asts.size(); + } + if (count == 0 || count > getMaxNumberOfArguments(with_structure)) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Storage S3 requires 1 to {} arguments. All supported signatures:\n{}", getMaxNumberOfArguments(with_structure), getSignatures(with_structure)); + auto key_value_args = parseKeyValueArguments(key_value_asts, context); + if (key_value_args.contains("structure")) + with_structure = false; + const auto & config = context->getConfigRef(); s3_capabilities = std::make_unique(getCapabilitiesFromConfig(config, "s3")); @@ -532,27 +617,35 @@ void StorageS3Configuration::fromAST(ASTs & args, ContextPtr context, bool with_ s3_settings->request_settings.updateIfChanged(endpoint_settings->request_settings); } - if (engine_args_to_idx.contains("format")) + if (auto format_value = getFromPositionOrKeyValue("format", args, engine_args_to_idx, key_value_args); + format_value.has_value()) { - auto format_ = checkAndGetLiteralArgument(args[engine_args_to_idx["format"]], "format"); + auto format_ = format_value.value(); /// Set format to configuration only of it's not 'auto', /// because we can have default format set in configuration. if (format_ != "auto") setFormat(format_); } - if (engine_args_to_idx.contains("structure")) - setStructure(checkAndGetLiteralArgument(args[engine_args_to_idx["structure"]], "structure")); + if (auto structure_value = getFromPositionOrKeyValue("structure", args, engine_args_to_idx, key_value_args); + structure_value.has_value()) + { + setStructure(structure_value.value()); + } - if (engine_args_to_idx.contains("compression_method")) - setCompressionMethod(checkAndGetLiteralArgument(args[engine_args_to_idx["compression_method"]], "compression_method")); + if (auto compression_method_value = getFromPositionOrKeyValue("compression_method", args, engine_args_to_idx, key_value_args); + compression_method_value.has_value()) + { + setCompressionMethod(compression_method_value.value()); + } - if (engine_args_to_idx.contains("partition_strategy")) + if (auto partition_strategy_value = getFromPositionOrKeyValue("partition_strategy", args, engine_args_to_idx, key_value_args); + partition_strategy_value.has_value()) { - const auto partition_strategy_name = checkAndGetLiteralArgument(args[engine_args_to_idx["partition_strategy"]], "partition_strategy"); + const auto & partition_strategy_name = partition_strategy_value.value(); const auto partition_strategy_type_opt = magic_enum::enum_cast(partition_strategy_name, magic_enum::case_insensitive); - if (!partition_strategy_type_opt) + if (!partition_strategy_type_opt.has_value()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "Partition strategy {} is not supported", partition_strategy_name); } @@ -560,22 +653,41 @@ void StorageS3Configuration::fromAST(ASTs & args, ContextPtr context, bool with_ setPartitionStrategyType(partition_strategy_type_opt.value()); } - if (engine_args_to_idx.contains("partition_columns_in_data_file")) - setPartitionColumnsInDataFile(checkAndGetLiteralArgument(args[engine_args_to_idx["partition_columns_in_data_file"]], "partition_columns_in_data_file")); + if (auto partition_columns_in_data_file_value = getFromPositionOrKeyValue("partition_columns_in_data_file", args, engine_args_to_idx, key_value_args); + partition_columns_in_data_file_value.has_value()) + { + setPartitionColumnsInDataFile(partition_columns_in_data_file_value.value()); + } else setPartitionColumnsInDataFile(getPartitionStrategyType() != PartitionStrategyFactory::StrategyType::HIVE); - if (engine_args_to_idx.contains("access_key_id")) - s3_settings->auth_settings[S3AuthSetting::access_key_id] = checkAndGetLiteralArgument(args[engine_args_to_idx["access_key_id"]], "access_key_id"); + if (auto access_key_id_value = getFromPositionOrKeyValue("access_key_id", args, engine_args_to_idx, key_value_args); + access_key_id_value.has_value()) + { + s3_settings->auth_settings[S3AuthSetting::access_key_id] = access_key_id_value.value(); + } - if (engine_args_to_idx.contains("secret_access_key")) - s3_settings->auth_settings[S3AuthSetting::secret_access_key] = checkAndGetLiteralArgument(args[engine_args_to_idx["secret_access_key"]], "secret_access_key"); + if (auto secret_access_key_value = getFromPositionOrKeyValue("secret_access_key", args, engine_args_to_idx, key_value_args); + secret_access_key_value.has_value()) + { + s3_settings->auth_settings[S3AuthSetting::secret_access_key] = secret_access_key_value.value(); + } - if (engine_args_to_idx.contains("session_token")) - s3_settings->auth_settings[S3AuthSetting::session_token] = checkAndGetLiteralArgument(args[engine_args_to_idx["session_token"]], "session_token"); + if (auto session_token_value = getFromPositionOrKeyValue("session_token", args, engine_args_to_idx, key_value_args); + session_token_value.has_value()) + { + s3_settings->auth_settings[S3AuthSetting::session_token] = session_token_value.value(); + } if (no_sign_request) + { s3_settings->auth_settings[S3AuthSetting::no_sign_request] = no_sign_request; + } + else if (auto no_sign_value = getFromPositionOrKeyValue("no_sign", args, {}, key_value_args); + no_sign_value.has_value()) + { + s3_settings->auth_settings[S3AuthSetting::no_sign_request] = no_sign_value.value(); + } static_configuration = !s3_settings->auth_settings[S3AuthSetting::access_key_id].value.empty() || s3_settings->auth_settings[S3AuthSetting::no_sign_request].changed; @@ -610,14 +722,87 @@ void StorageS3Configuration::addStructureAndFormatToArgsIfNeeded( auto extra_credentials = extractExtraCredentials(args); HTTPHeaderEntries tmp_headers; + size_t count = StorageURL::evalArgsAndCollectHeaders(args, tmp_headers, context); + ASTs key_value_asts; + auto * first_key_value_arg_it = getFirstKeyValueArgument(args); + if (first_key_value_arg_it != args.end()) + { + key_value_asts = ASTs(first_key_value_arg_it, args.end()); + count -= key_value_asts.size(); + } + if (count == 0 || count > getMaxNumberOfArguments()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected 1 to {} arguments in table function s3, got {}", getMaxNumberOfArguments(), count); + { + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Expected 1 to {} arguments in table function s3, got {}", + getMaxNumberOfArguments(), count); + } auto format_literal = std::make_shared(format_); auto structure_literal = std::make_shared(structure_); + bool format_in_key_value = false; + bool structure_in_key_value = false; + for (auto * it = first_key_value_arg_it; it != args.end(); ++it) + { + const auto & arg = *it; + const auto * function_ast = arg->as(); + if (!function_ast || function_ast->name != "equals") + continue; + + auto * args_expr = assert_cast(function_ast->arguments.get()); + auto & children = args_expr->children; + if (children.size() != 2) + { + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Key value argument is incorrect: expected 2 arguments, got {}", + children.size()); + } + + auto literal = evaluateConstantExpressionOrIdentifierAsLiteral(children[0], context); + + auto arg_name_value = literal->as()->value; + if (arg_name_value.getType() != Field::Types::Which::String) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected string as credential name"); + auto arg_name = arg_name_value.safeGet(); + + if (arg_name == "format") + { + children[1] = format_literal; + format_in_key_value = true; + } + else if (arg_name == "structure") + { + children[1] = structure_literal; + structure_in_key_value = true; + } + } + + if (format_in_key_value && structure_in_key_value) + { + /// Add extracted extra credentials to the end of the args. + if (extra_credentials) + args.push_back(extra_credentials); + return; + } + else if (format_in_key_value && with_structure) + { + /// Structure goes right after format, so if format is in key-value, + /// then structure is required to be key-value. + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected positional arguments to go before key-value arguments"); + } + else if (structure_in_key_value) + { + with_structure = false; + } + + /// We will return it back at the end. + args.erase(first_key_value_arg_it, args.end()); + /// s3(s3_url) if (count == 1) { @@ -776,6 +961,9 @@ void StorageS3Configuration::addStructureAndFormatToArgsIfNeeded( args[5] = structure_literal; } + if (!key_value_asts.empty()) + args.insert(args.end(), std::make_move_iterator(key_value_asts.begin()), std::make_move_iterator(key_value_asts.end())); + /// Add extracted extra credentials to the end of the args. if (extra_credentials) args.push_back(extra_credentials); @@ -809,3 +997,4 @@ ASTPtr StorageS3Configuration::createArgsWithAccessData() const } #endif + diff --git a/tests/integration/test_mask_sensitive_info/test.py b/tests/integration/test_mask_sensitive_info/test.py index ad4962b6707f..4d7b6a879714 100644 --- a/tests/integration/test_mask_sensitive_info/test.py +++ b/tests/integration/test_mask_sensitive_info/test.py @@ -329,6 +329,7 @@ def test_create_table(): f"Kafka() SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '{password}', format_avro_schema_registry_url = 'http://schema_user:{password}@'", f"Kafka() SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '{password}', format_avro_schema_registry_url = 'http://schema_user:{password}@domain.com'", + f"S3('http://minio1:9001/root/data/test5.csv.gz', 'CSV', access_key_id = 'minio', secret_access_key = '{password}', compression_method = 'gzip')", ] def make_test_case(i): @@ -419,6 +420,7 @@ def make_test_case(i): "CREATE TABLE table44 (`x` int) ENGINE = Kafka SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '[HIDDEN]', format_avro_schema_registry_url = 'http://schema_user:[HIDDEN]@'", "CREATE TABLE table45 (`x` int) ENGINE = Kafka SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '[HIDDEN]', format_avro_schema_registry_url = 'http://schema_user:[HIDDEN]@domain.com'", + "CREATE TABLE table46 (`x` int) ENGINE = S3('http://minio1:9001/root/data/test5.csv.gz', 'CSV', access_key_id = 'minio', secret_access_key = '[HIDDEN]', compression_method = 'gzip')", ], must_not_contain=[password], ) diff --git a/tests/integration/test_storage_s3/test.py b/tests/integration/test_storage_s3/test.py index 0809e3f8c798..b44b2f807ebc 100644 --- a/tests/integration/test_storage_s3/test.py +++ b/tests/integration/test_storage_s3/test.py @@ -16,6 +16,8 @@ from helpers.network import PartitionManager from helpers.s3_tools import prepare_s3_bucket from helpers.test_tools import exec_query_with_retry +from helpers.config_cluster import minio_secret_key +from helpers.s3_queue_common import generate_random_string MINIO_INTERNAL_PORT = 9001 @@ -2520,3 +2522,91 @@ def test_filesystem_cache(started_cluster): f"SELECT ProfileEvents['S3GetObject'] FROM system.query_log WHERE query_id = '{query_id}' AND type = 'QueryFinish'" ) ) + + +def test_key_value_args(started_cluster): + node = started_cluster.instances["dummy"] + restricted_node = started_cluster.instances["restricted_dummy"] + table_name = f"test_key_value_args_{generate_random_string()}" + bucket = started_cluster.minio_bucket + + url = f"http://{started_cluster.minio_host}:{started_cluster.minio_port}/{bucket}/{table_name}_data" + + # Check format. + assert ( + "The data format cannot be detected by the contents" + in node.query_and_get_error( + f"insert into function s3('{url}', structure = 'a Int32, b String') select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + ) + node.query( + f"insert into function s3('{url}', format = TSVRaw, structure = 'a Int32, b String') select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + node.query( + f"insert into function s3('{url}', 'TSVRaw', structure = 'a Int32, b String') select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + + # Check structure. + assert ( + "a\tInt32\t\t\t\t\t\nb\tString" + in node.query( + f"describe table s3('{url}', format = TSVRaw, structure = 'a Int32, b String')" + ).strip() + ) + assert 2 == int( + node.query( + f"select a from s3('{url}', format = TSVRaw, structure = 'a Int32, b String') where b = '2'" + ) + ) + + # Check access_key_id, secret_access_key + assert ( + "The request signature we calculated does not match the signature you provided" + in node.query_and_get_error( + f"insert into function s3('{url}', structure = 'a Int32, b String', access_key_id = 'minio', secret_access_key = 'keko', format = 'TSVRaw') select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + ) + node.query( + f"insert into function s3('{url}', structure = 'a Int32, b String', access_key_id = 'minio', secret_access_key = '{minio_secret_key}', format = 'TSVRaw') select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + assert 2 == int( + node.query( + f"select a from s3('{url}', format = TSVRaw, access_key_id = 'minio', secret_access_key = '{minio_secret_key}', structure = 'a Int32, b String') where b = '2'" + ) + ) + + # Check session_token + assert "Failed to get object info" in node.query_and_get_error( + f"select a from s3('{url}', format = TSVRaw, access_key_id = 'minio', secret_access_key = '{minio_secret_key}', structure = 'a Int32, b String', session_token = 'kek') where b = '2'" + ) + + # Check structure + assert "Cannot parse DateTime" in node.query_and_get_error( + f"select a from s3('{url}', format = TSVRaw, access_key_id = 'minio', secret_access_key = '{minio_secret_key}', structure = 'a Int32, b DateTime') where b = '2'" + ) + + # Check compression_method + assert "inflate failed" in node.query_and_get_error( + f"select a from s3('{url}', format = TSVRaw, structure = 'a Int32, b String', access_key_id = 'minio', secret_access_key = '{minio_secret_key}', compression_method = 'gzip') where b = '2'" + ) + assert 2 == int( + node.query( + f"select a from s3('{url}', format = TSVRaw, structure = 'a Int32, b String', access_key_id = 'minio', secret_access_key = '{minio_secret_key}', compression_method = 'none') where b = '2'" + ) + ) + + # Check partition strategy + assert "is not supported" in node.query_and_get_error( + f"insert into function s3('{url}', format = TSVRaw, structure = 'a Int32, b String', partition_strategy='hivy') partition by b select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + node.query( + f"insert into function s3('{url}', format = TSVRaw, structure = 'a Int32, b String', partition_strategy='hive') partition by b select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + + # Check order - positional args before key-value + assert ( + "Expected positional arguments to go before key-value arguments" + in node.query_and_get_error( + f"insert into function s3('{url}', structure = 'a Int32, b String', partition_strategy='hive', TSV) partition by b select number, toString(number) from numbers(10) settings s3_truncate_on_insert=1" + ) + ) diff --git a/tests/integration/test_storage_s3/test_invalid_env_credentials.py b/tests/integration/test_storage_s3/test_invalid_env_credentials.py index 53f9a6968de6..933f5a359d4a 100644 --- a/tests/integration/test_storage_s3/test_invalid_env_credentials.py +++ b/tests/integration/test_storage_s3/test_invalid_env_credentials.py @@ -118,6 +118,7 @@ def test_with_invalid_environment_credentials(started_cluster): for bucket, auth in [ (started_cluster.minio_restricted_bucket, f"'minio', '{minio_secret_key}'"), (started_cluster.minio_bucket, "NOSIGN"), + (started_cluster.minio_bucket, "no_sign = 1"), ]: instance.query( f"insert into function s3('http://{started_cluster.minio_host}:{started_cluster.minio_port}/{bucket}/test_cache4.jsonl', {auth}) select * from numbers(100) settings s3_truncate_on_insert=1"