Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Parsers/FunctionSecretArgumentsFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class FunctionSecretArgumentsFinder
break;
if (f->name() == "headers")
result.nested_maps.push_back(f->name());
else if (f->name() != "extra_credentials")
else if (f->name() != "extra_credentials" && f->name() != "equals")
break;
count -= 1;
}
Expand All @@ -266,6 +266,8 @@ class FunctionSecretArgumentsFinder
return;
}

findSecretNamedArgument("secret_access_key", url_arg_idx);

/// We should check other arguments first because we don't need to do any replacement in case of
/// s3('url', NOSIGN, 'format' [, 'compression'] [, extra_credentials(..)] [, headers(..)])
/// s3('url', 'format', 'structure' [, 'compression'] [, extra_credentials(..)] [, headers(..)])
Expand Down Expand Up @@ -625,6 +627,8 @@ class FunctionSecretArgumentsFinder
return;
}

findSecretNamedArgument("secret_access_key", 0);

/// We should check other arguments first because we don't need to do any replacement in case of
/// S3('url', NOSIGN, 'format' [, 'compression'] [, extra_credentials(..)] [, headers(..)])
/// S3('url', 'format', 'compression' [, extra_credentials(..)] [, headers(..)])
Expand Down
225 changes: 207 additions & 18 deletions src/Storages/ObjectStorage/S3/Configuration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,18 +294,103 @@ bool StorageS3Configuration::collectCredentials(ASTPtr maybe_credentials, S3::S3
return true;
}

template <typename T>
static std::optional<T> getFromPositionOrKeyValue(
const std::string & key,
const ASTs & args,
const std::unordered_map<std::string_view, size_t> & engine_args_to_idx,
const std::unordered_map<std::string, Field> & key_value_args)
{
if (auto arg_it = engine_args_to_idx.find(key); arg_it != engine_args_to_idx.end())
return checkAndGetLiteralArgument<T>(args[arg_it->second], key);

if (auto arg_it = key_value_args.find(key); arg_it != key_value_args.end())
return arg_it->second.safeGet<T>();

return std::nullopt;
};

static std::unordered_map<std::string, Field> parseKeyValueArguments(const ASTs & function_args, ContextPtr context)
{
std::unordered_map<std::string, Field> key_value_args;
for (const auto & arg : function_args)
{
const auto * function_ast = arg->as<ASTFunction>();
if (!function_ast || function_ast->name != "equals")
continue;

auto * args_expr = assert_cast<ASTExpressionList *>(function_ast->arguments.get());
auto & children = args_expr->children;
if (children.size() != 2)
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS,
"Key value argument is incorrect: expected 2 arguments, got {}",
children.size());
}

auto key_literal = evaluateConstantExpressionOrIdentifierAsLiteral(children[0], context);
auto value_literal = evaluateConstantExpressionOrIdentifierAsLiteral(children[1], context);

auto arg_name_value = key_literal->as<ASTLiteral>()->value;
if (arg_name_value.getType() != Field::Types::Which::String)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected string as credential name");

auto arg_name = arg_name_value.safeGet<String>();
auto arg_value = value_literal->as<ASTLiteral>()->value;

auto inserted = key_value_args.emplace(arg_name, arg_value).second;
if (!inserted)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Duplicate key value argument: {}", arg_name);
}
return key_value_args;
}

static ASTs::iterator getFirstKeyValueArgument(ASTs & args)
{
ASTs::iterator first_key_value_arg_it = args.end();
for (auto * it = args.begin(); it != args.end(); ++it)
{
const auto * function_ast = (*it)->as<ASTFunction>();
if (function_ast && function_ast->name == "equals")
{
if (first_key_value_arg_it == args.end())
first_key_value_arg_it = it;
}
else if (first_key_value_arg_it != args.end())
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS,
"Expected positional arguments to go before key-value arguments");
}
}
return first_key_value_arg_it;
}

void StorageS3Configuration::fromAST(ASTs & args, ContextPtr context, bool with_structure)
{
auto extra_credentials = extractExtraCredentials(args);

size_t count = StorageURL::evalArgsAndCollectHeaders(args, headers_from_ast, context);

ASTs key_value_asts;
if (auto * first_key_value_arg_it = getFirstKeyValueArgument(args);
first_key_value_arg_it != args.end())
{
key_value_asts = ASTs(first_key_value_arg_it, args.end());
count -= key_value_asts.size();
}

if (count == 0 || count > getMaxNumberOfArguments(with_structure))
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Storage S3 requires 1 to {} arguments. All supported signatures:\n{}",
getMaxNumberOfArguments(with_structure),
getSignatures(with_structure));

auto key_value_args = parseKeyValueArguments(key_value_asts, context);
if (key_value_args.contains("structure"))
with_structure = false;

const auto & config = context->getConfigRef();
s3_capabilities = std::make_unique<S3Capabilities>(getCapabilitiesFromConfig(config, "s3"));

Expand Down Expand Up @@ -532,50 +617,77 @@ void StorageS3Configuration::fromAST(ASTs & args, ContextPtr context, bool with_
s3_settings->request_settings.updateIfChanged(endpoint_settings->request_settings);
}

if (engine_args_to_idx.contains("format"))
if (auto format_value = getFromPositionOrKeyValue<String>("format", args, engine_args_to_idx, key_value_args);
format_value.has_value())
{
auto format_ = checkAndGetLiteralArgument<String>(args[engine_args_to_idx["format"]], "format");
auto format_ = format_value.value();
/// Set format to configuration only of it's not 'auto',
/// because we can have default format set in configuration.
if (format_ != "auto")
setFormat(format_);
}

if (engine_args_to_idx.contains("structure"))
setStructure(checkAndGetLiteralArgument<String>(args[engine_args_to_idx["structure"]], "structure"));
if (auto structure_value = getFromPositionOrKeyValue<String>("structure", args, engine_args_to_idx, key_value_args);
structure_value.has_value())
{
setStructure(structure_value.value());
}

if (engine_args_to_idx.contains("compression_method"))
setCompressionMethod(checkAndGetLiteralArgument<String>(args[engine_args_to_idx["compression_method"]], "compression_method"));
if (auto compression_method_value = getFromPositionOrKeyValue<String>("compression_method", args, engine_args_to_idx, key_value_args);
compression_method_value.has_value())
{
setCompressionMethod(compression_method_value.value());
}

if (engine_args_to_idx.contains("partition_strategy"))
if (auto partition_strategy_value = getFromPositionOrKeyValue<String>("partition_strategy", args, engine_args_to_idx, key_value_args);
partition_strategy_value.has_value())
{
const auto partition_strategy_name = checkAndGetLiteralArgument<String>(args[engine_args_to_idx["partition_strategy"]], "partition_strategy");
const auto & partition_strategy_name = partition_strategy_value.value();
const auto partition_strategy_type_opt = magic_enum::enum_cast<PartitionStrategyFactory::StrategyType>(partition_strategy_name, magic_enum::case_insensitive);

if (!partition_strategy_type_opt)
if (!partition_strategy_type_opt.has_value())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Partition strategy {} is not supported", partition_strategy_name);
}

setPartitionStrategyType(partition_strategy_type_opt.value());
}

if (engine_args_to_idx.contains("partition_columns_in_data_file"))
setPartitionColumnsInDataFile(checkAndGetLiteralArgument<bool>(args[engine_args_to_idx["partition_columns_in_data_file"]], "partition_columns_in_data_file"));
if (auto partition_columns_in_data_file_value = getFromPositionOrKeyValue<bool>("partition_columns_in_data_file", args, engine_args_to_idx, key_value_args);
partition_columns_in_data_file_value.has_value())
{
setPartitionColumnsInDataFile(partition_columns_in_data_file_value.value());
}
else
setPartitionColumnsInDataFile(getPartitionStrategyType() != PartitionStrategyFactory::StrategyType::HIVE);

if (engine_args_to_idx.contains("access_key_id"))
s3_settings->auth_settings[S3AuthSetting::access_key_id] = checkAndGetLiteralArgument<String>(args[engine_args_to_idx["access_key_id"]], "access_key_id");
if (auto access_key_id_value = getFromPositionOrKeyValue<String>("access_key_id", args, engine_args_to_idx, key_value_args);
access_key_id_value.has_value())
{
s3_settings->auth_settings[S3AuthSetting::access_key_id] = access_key_id_value.value();
}

if (engine_args_to_idx.contains("secret_access_key"))
s3_settings->auth_settings[S3AuthSetting::secret_access_key] = checkAndGetLiteralArgument<String>(args[engine_args_to_idx["secret_access_key"]], "secret_access_key");
if (auto secret_access_key_value = getFromPositionOrKeyValue<String>("secret_access_key", args, engine_args_to_idx, key_value_args);
secret_access_key_value.has_value())
{
s3_settings->auth_settings[S3AuthSetting::secret_access_key] = secret_access_key_value.value();
}

if (engine_args_to_idx.contains("session_token"))
s3_settings->auth_settings[S3AuthSetting::session_token] = checkAndGetLiteralArgument<String>(args[engine_args_to_idx["session_token"]], "session_token");
if (auto session_token_value = getFromPositionOrKeyValue<String>("session_token", args, engine_args_to_idx, key_value_args);
session_token_value.has_value())
{
s3_settings->auth_settings[S3AuthSetting::session_token] = session_token_value.value();
}

if (no_sign_request)
{
s3_settings->auth_settings[S3AuthSetting::no_sign_request] = no_sign_request;
}
else if (auto no_sign_value = getFromPositionOrKeyValue<bool>("no_sign", args, {}, key_value_args);
no_sign_value.has_value())
{
s3_settings->auth_settings[S3AuthSetting::no_sign_request] = no_sign_value.value();
}

static_configuration = !s3_settings->auth_settings[S3AuthSetting::access_key_id].value.empty() || s3_settings->auth_settings[S3AuthSetting::no_sign_request].changed;

Expand Down Expand Up @@ -610,14 +722,87 @@ void StorageS3Configuration::addStructureAndFormatToArgsIfNeeded(
auto extra_credentials = extractExtraCredentials(args);

HTTPHeaderEntries tmp_headers;

size_t count = StorageURL::evalArgsAndCollectHeaders(args, tmp_headers, context);

ASTs key_value_asts;
auto * first_key_value_arg_it = getFirstKeyValueArgument(args);
if (first_key_value_arg_it != args.end())
{
key_value_asts = ASTs(first_key_value_arg_it, args.end());
count -= key_value_asts.size();
}

if (count == 0 || count > getMaxNumberOfArguments())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected 1 to {} arguments in table function s3, got {}", getMaxNumberOfArguments(), count);
{
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Expected 1 to {} arguments in table function s3, got {}",
getMaxNumberOfArguments(), count);
}

auto format_literal = std::make_shared<ASTLiteral>(format_);
auto structure_literal = std::make_shared<ASTLiteral>(structure_);

bool format_in_key_value = false;
bool structure_in_key_value = false;
for (auto * it = first_key_value_arg_it; it != args.end(); ++it)
{
const auto & arg = *it;
const auto * function_ast = arg->as<ASTFunction>();
if (!function_ast || function_ast->name != "equals")
continue;

auto * args_expr = assert_cast<ASTExpressionList *>(function_ast->arguments.get());
auto & children = args_expr->children;
if (children.size() != 2)
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS,
"Key value argument is incorrect: expected 2 arguments, got {}",
children.size());
}

auto literal = evaluateConstantExpressionOrIdentifierAsLiteral(children[0], context);

auto arg_name_value = literal->as<ASTLiteral>()->value;
if (arg_name_value.getType() != Field::Types::Which::String)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected string as credential name");
auto arg_name = arg_name_value.safeGet<String>();

if (arg_name == "format")
{
children[1] = format_literal;
format_in_key_value = true;
}
else if (arg_name == "structure")
{
children[1] = structure_literal;
structure_in_key_value = true;
}
}

if (format_in_key_value && structure_in_key_value)
{
/// Add extracted extra credentials to the end of the args.
if (extra_credentials)
args.push_back(extra_credentials);
return;
}
else if (format_in_key_value && with_structure)
{
/// Structure goes right after format, so if format is in key-value,
/// then structure is required to be key-value.
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected positional arguments to go before key-value arguments");
}
else if (structure_in_key_value)
{
with_structure = false;
}

/// We will return it back at the end.
args.erase(first_key_value_arg_it, args.end());

/// s3(s3_url)
if (count == 1)
{
Expand Down Expand Up @@ -776,6 +961,9 @@ void StorageS3Configuration::addStructureAndFormatToArgsIfNeeded(
args[5] = structure_literal;
}

if (!key_value_asts.empty())
args.insert(args.end(), std::make_move_iterator(key_value_asts.begin()), std::make_move_iterator(key_value_asts.end()));

/// Add extracted extra credentials to the end of the args.
if (extra_credentials)
args.push_back(extra_credentials);
Expand Down Expand Up @@ -809,3 +997,4 @@ ASTPtr StorageS3Configuration::createArgsWithAccessData() const
}

#endif

2 changes: 2 additions & 0 deletions tests/integration/test_mask_sensitive_info/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def test_create_table():

f"Kafka() SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '{password}', format_avro_schema_registry_url = 'http://schema_user:{password}@'",
f"Kafka() SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '{password}', format_avro_schema_registry_url = 'http://schema_user:{password}@domain.com'",
f"S3('http://minio1:9001/root/data/test5.csv.gz', 'CSV', access_key_id = 'minio', secret_access_key = '{password}', compression_method = 'gzip')",
]

def make_test_case(i):
Expand Down Expand Up @@ -419,6 +420,7 @@ def make_test_case(i):

"CREATE TABLE table44 (`x` int) ENGINE = Kafka SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '[HIDDEN]', format_avro_schema_registry_url = 'http://schema_user:[HIDDEN]@'",
"CREATE TABLE table45 (`x` int) ENGINE = Kafka SETTINGS kafka_broker_list = '127.0.0.1', kafka_topic_list = 'topic', kafka_group_name = 'group', kafka_format = 'JSONEachRow', kafka_security_protocol = 'sasl_ssl', kafka_sasl_mechanism = 'PLAIN', kafka_sasl_username = 'user', kafka_sasl_password = '[HIDDEN]', format_avro_schema_registry_url = 'http://schema_user:[HIDDEN]@domain.com'",
"CREATE TABLE table46 (`x` int) ENGINE = S3('http://minio1:9001/root/data/test5.csv.gz', 'CSV', access_key_id = 'minio', secret_access_key = '[HIDDEN]', compression_method = 'gzip')",
],
must_not_contain=[password],
)
Expand Down
Loading
Loading