diff --git a/compose/config/config.py b/compose/config/config.py index 055ae18aca2..a00032d4406 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -30,6 +30,7 @@ from .types import ServiceLink from .types import VolumeFromSpec from .types import VolumeSpec +from .validation import load_jsonschema from .validation import match_named_volumes from .validation import validate_against_fields_schema from .validation import validate_against_service_schema @@ -388,16 +389,19 @@ def merge_services(base, override): return build_services(service_config) -def interpolate_config_section(filename, config, section): +def interpolate_config_section(filename, config, section, jsonschema): validate_config_section(filename, config, section) - return interpolate_environment_variables(config, section) + return interpolate_environment_variables(config, section, jsonschema) def process_config_file(config_file, service_name=None): + jsonschema = load_jsonschema('service', config_file.version) + services = interpolate_config_section( config_file.filename, config_file.get_service_dicts(), - 'service') + 'service', + jsonschema) if config_file.version == V2_0: processed_config = dict(config_file.config) @@ -405,11 +409,13 @@ def process_config_file(config_file, service_name=None): processed_config['volumes'] = interpolate_config_section( config_file.filename, config_file.get_volumes(), - 'volume') + 'volume', + jsonschema) processed_config['networks'] = interpolate_config_section( config_file.filename, config_file.get_networks(), - 'network') + 'network', + jsonschema) if config_file.version == V1: processed_config = services diff --git a/compose/config/interpolation.py b/compose/config/interpolation.py index 1e56ebb6685..65fd8e9837a 100644 --- a/compose/config/interpolation.py +++ b/compose/config/interpolation.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import unicode_literals +import json import logging import os from string import Template @@ -8,15 +9,17 @@ import six from .errors import ConfigurationError + + log = logging.getLogger(__name__) -def interpolate_environment_variables(config, section): +def interpolate_environment_variables(config, section, schema): mapping = BlankDefaultDict(os.environ) def process_item(name, config_dict): return dict( - (key, interpolate_value(name, key, val, section, mapping)) + (key, interpolate_value(name, key, val, section, mapping, schema)) for key, val in (config_dict or {}).items() ) @@ -26,9 +29,34 @@ def process_item(name, config_dict): ) -def interpolate_value(name, config_key, value, section, mapping): +def _cast_interpolated_inside_list(interpolated_value, possible_types): + castable_types = [t for t in possible_types if t != "string"] + for possible_type in castable_types: + v = _cast_interpolated(interpolated_value, possible_type) + if v is not None: + return v + return interpolated_value + + +def _cast_interpolated(interpolated_value, field_def): + field_type = field_def + if "type" in field_def: + field_type = field_def["type"] + if isinstance(field_type, list): + return _cast_interpolated_inside_list(interpolated_value, field_type) + if field_type == "array" and isinstance(interpolated_value, list): + return [_cast_interpolated(subfield, field_def["items"]["type"]) for subfield in interpolated_value] + try: + return json.loads(interpolated_value.replace("'", "\"")) + except ValueError: + return interpolated_value + + +def interpolate_value(name, config_key, value, section, mapping, schema): try: - return recursive_interpolate(value, mapping) + properties_schema = schema["definitions"]["service"]["properties"] + field_def = properties_schema.get(config_key, None) + return recursive_interpolate(value, mapping, field_def) except InvalidInterpolation as e: raise ConfigurationError( 'Invalid interpolation format for "{config_key}" option ' @@ -39,21 +67,42 @@ def interpolate_value(name, config_key, value, section, mapping): string=e.string)) -def recursive_interpolate(obj, mapping): +def _get_field_types(field_def): + if "type" in field_def: + return field_def["type"] + if "oneOf" in field_def: + return [_get_field_types(f_def) for f_def in field_def["oneOf"]] + + +def _get_sub_field_def(field_def, name): + if field_def is None: + return None + for option_def in field_def.get("oneOf", []): + if name in option_def.get("properties", []): + return option_def["properties"][name] + return field_def.get(name, None) + + +def recursive_interpolate(obj, mapping, field_def): if isinstance(obj, six.string_types): - return interpolate(obj, mapping) + value = interpolate(obj, mapping) + if value != obj and field_def is not None: + return _cast_interpolated(value, field_def) + return value elif isinstance(obj, dict): return dict( - (key, recursive_interpolate(val, mapping)) + (key, recursive_interpolate(val, mapping, _get_sub_field_def(field_def, key))) for (key, val) in obj.items() ) elif isinstance(obj, list): - return [recursive_interpolate(val, mapping) for val in obj] + return [recursive_interpolate(val, mapping, _get_sub_field_def(field_def, "items")) for val in obj] else: return obj def interpolate(string, mapping): + if '$' not in string: + return string try: return Template(string).substitute(mapping) except ValueError: diff --git a/compose/config/validation.py b/compose/config/validation.py index 4e2083cbc29..96ea1172bad 100644 --- a/compose/config/validation.py +++ b/compose/config/validation.py @@ -386,10 +386,10 @@ def format_error_message(error): def validate_against_fields_schema(config_file): - schema_filename = "fields_schema_v{0}.json".format(config_file.version) + schema = load_jsonschema("fields", config_file.version) _validate_against_schema( config_file.config, - schema_filename, + schema, format_checker=["ports", "expose", "bool-value-in-mapping"], filename=config_file.filename) @@ -397,19 +397,33 @@ def validate_against_fields_schema(config_file): def validate_against_service_schema(config, service_name, version): _validate_against_schema( config, - "service_schema_v{0}.json".format(version), + load_jsonschema("service", version), format_checker=["ports"], path_prefix=[service_name]) +def load_jsonschema(prefix, version): + schema_filename = "{0}_schema_v{1}.json".format(prefix, version) + config_source_dir = os.path.dirname(os.path.abspath(__file__)) + + if sys.platform == "win32": + config_source_dir = config_source_dir.replace('\\', '/') + + schema_file = os.path.join(config_source_dir, schema_filename) + + with open(schema_file, "r") as schema_fh: + schema = json.load(schema_fh) + return schema + + def _validate_against_schema( config, - schema_filename, + schema, format_checker=(), path_prefix=None, filename=None): - config_source_dir = os.path.dirname(os.path.abspath(__file__)) + config_source_dir = os.path.dirname(os.path.abspath(__file__)) if sys.platform == "win32": file_pre_fix = "///" config_source_dir = config_source_dir.replace('\\', '/') @@ -417,10 +431,6 @@ def _validate_against_schema( file_pre_fix = "//" resolver_full_path = "file:{}{}/".format(file_pre_fix, config_source_dir) - schema_file = os.path.join(config_source_dir, schema_filename) - - with open(schema_file, "r") as schema_fh: - schema = json.load(schema_fh) resolver = RefResolver(resolver_full_path, schema) validation_output = Draft4Validator( diff --git a/tests/unit/config/interpolation_test.py b/tests/unit/config/interpolation_test.py index 0691e88652f..168829d577b 100644 --- a/tests/unit/config/interpolation_test.py +++ b/tests/unit/config/interpolation_test.py @@ -7,6 +7,8 @@ import pytest from compose.config.interpolation import interpolate_environment_variables +from compose.config.validation import load_jsonschema +from compose.const import COMPOSEFILE_V2_0 as V2_0 @pytest.yield_fixture @@ -14,6 +16,12 @@ def mock_env(): with mock.patch.dict(os.environ): os.environ['USER'] = 'jenny' os.environ['FOO'] = 'bar' + os.environ['RO'] = 'false' + os.environ['VOLUMES'] = "['/tmp/bar:/bar', '/tmp/foo:/foo']" + os.environ['VOLUMES_ITEM'] = '/tmp/foo:/foo' + os.environ['CPU_SHARES'] = '512' + os.environ['PORTS'] = "[ 8080 , 8089 ]" + os.environ['SINGLE_PORT'] = '8787' yield @@ -21,7 +29,10 @@ def test_interpolate_environment_variables_in_services(mock_env): services = { 'servivea': { 'image': 'example:${USER}', + 'entrypoint': '/bin/bash', 'volumes': ['$FOO:/target'], + 'read_only': '${RO}', + 'cpu_shares': '${CPU_SHARES}', 'logging': { 'driver': '${FOO}', 'options': { @@ -33,7 +44,10 @@ def test_interpolate_environment_variables_in_services(mock_env): expected = { 'servivea': { 'image': 'example:jenny', + 'entrypoint': '/bin/bash', 'volumes': ['bar:/target'], + 'read_only': False, + 'cpu_shares': 512, 'logging': { 'driver': 'bar', 'options': { @@ -42,7 +56,59 @@ def test_interpolate_environment_variables_in_services(mock_env): } } } - assert interpolate_environment_variables(services, 'service') == expected + assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected + + +def test_interpolate_environment_variables_arrays_in_services(mock_env): + services = { + 'servivea': { + 'image': 'example:${USER}', + 'volumes': '${VOLUMES}', + 'cpu_shares': '${CPU_SHARES}', + } + } + expected = { + 'servivea': { + 'image': 'example:jenny', + 'volumes': ['/tmp/bar:/bar', '/tmp/foo:/foo'], + 'cpu_shares': 512, + } + } + assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected + + +def test_interpolate_environment_variables_array_element_in_services(mock_env): + services = { + 'servivea': { + 'image': 'example:${USER}', + 'volumes': ['/tmp/bar:/bar', '${VOLUMES_ITEM}'], + 'cpu_shares': '${CPU_SHARES}', + } + } + expected = { + 'servivea': { + 'image': 'example:jenny', + 'volumes': ['/tmp/bar:/bar', '/tmp/foo:/foo'], + 'cpu_shares': 512, + } + } + assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected + + +def test_interpolate_environment_variables_array_numbers_in_services(mock_env): + services = { + 'servivea': { + 'expose': '${PORTS}', + 'ports': [8080, '${SINGLE_PORT}'] + } + } + expected = { + 'servivea': { + 'expose': [8080, 8089], + 'ports': [8080, 8787] + } + } + assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected def test_interpolate_environment_variables_in_volumes(mock_env): @@ -66,4 +132,4 @@ def test_interpolate_environment_variables_in_volumes(mock_env): }, 'other': {}, } - assert interpolate_environment_variables(volumes, 'volume') == expected + assert interpolate_environment_variables(volumes, 'volume', load_jsonschema('service', V2_0)) == expected