Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions compose/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .types import ServiceLink
from .types import VolumeFromSpec
from .types import VolumeSpec
from .validation import load_jsonschema
from .validation import match_named_volumes
from .validation import validate_against_fields_schema
from .validation import validate_against_service_schema
Expand Down Expand Up @@ -388,28 +389,33 @@ def merge_services(base, override):
return build_services(service_config)


def interpolate_config_section(filename, config, section):
def interpolate_config_section(filename, config, section, jsonschema):
validate_config_section(filename, config, section)
return interpolate_environment_variables(config, section)
return interpolate_environment_variables(config, section, jsonschema)


def process_config_file(config_file, service_name=None):
jsonschema = load_jsonschema('service', config_file.version)

services = interpolate_config_section(
config_file.filename,
config_file.get_service_dicts(),
'service')
'service',
jsonschema)

if config_file.version == V2_0:
processed_config = dict(config_file.config)
processed_config['services'] = services
processed_config['volumes'] = interpolate_config_section(
config_file.filename,
config_file.get_volumes(),
'volume')
'volume',
jsonschema)
processed_config['networks'] = interpolate_config_section(
config_file.filename,
config_file.get_networks(),
'network')
'network',
jsonschema)

if config_file.version == V1:
processed_config = services
Expand Down
65 changes: 57 additions & 8 deletions compose/config/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from __future__ import absolute_import
from __future__ import unicode_literals

import json
import logging
import os
from string import Template

import six

from .errors import ConfigurationError


log = logging.getLogger(__name__)


def interpolate_environment_variables(config, section):
def interpolate_environment_variables(config, section, schema):
mapping = BlankDefaultDict(os.environ)

def process_item(name, config_dict):
return dict(
(key, interpolate_value(name, key, val, section, mapping))
(key, interpolate_value(name, key, val, section, mapping, schema))
for key, val in (config_dict or {}).items()
)

Expand All @@ -26,9 +29,34 @@ def process_item(name, config_dict):
)


def interpolate_value(name, config_key, value, section, mapping):
def _cast_interpolated_inside_list(interpolated_value, possible_types):
castable_types = [t for t in possible_types if t != "string"]
for possible_type in castable_types:
v = _cast_interpolated(interpolated_value, possible_type)
if v is not None:
return v
return interpolated_value


def _cast_interpolated(interpolated_value, field_def):
field_type = field_def
if "type" in field_def:
field_type = field_def["type"]
if isinstance(field_type, list):
return _cast_interpolated_inside_list(interpolated_value, field_type)
if field_type == "array" and isinstance(interpolated_value, list):
return [_cast_interpolated(subfield, field_def["items"]["type"]) for subfield in interpolated_value]
try:
return json.loads(interpolated_value.replace("'", "\""))
except ValueError:
return interpolated_value

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems overly complex. I think supporting arrays is probably the most contentious part of this feature. I think we should try to keep it as straightforward as possible.

How about:

return json.loads(interpolated_value)

Although, I wonder if it should actually be yaml.safe_load() since the config is yaml format. Using yaml would support json as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said I have not written python for years, but it seems that json.loads does not work with an Array.
Looks like I am not the only one to face the issue
=> http://stackoverflow.com/questions/10973614/convert-json-array-to-python-list

However, it may be just because of issue between simple quotes (') and double quotes (")

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's it. json only supports double quotes for strings.


def interpolate_value(name, config_key, value, section, mapping, schema):
try:
return recursive_interpolate(value, mapping)
properties_schema = schema["definitions"]["service"]["properties"]
field_def = properties_schema.get(config_key, None)
return recursive_interpolate(value, mapping, field_def)
except InvalidInterpolation as e:
raise ConfigurationError(
'Invalid interpolation format for "{config_key}" option '
Expand All @@ -39,21 +67,42 @@ def interpolate_value(name, config_key, value, section, mapping):
string=e.string))


def recursive_interpolate(obj, mapping):
def _get_field_types(field_def):
if "type" in field_def:
return field_def["type"]
if "oneOf" in field_def:
return [_get_field_types(f_def) for f_def in field_def["oneOf"]]


def _get_sub_field_def(field_def, name):
if field_def is None:
return None
for option_def in field_def.get("oneOf", []):
if name in option_def.get("properties", []):
return option_def["properties"][name]
return field_def.get(name, None)


def recursive_interpolate(obj, mapping, field_def):
if isinstance(obj, six.string_types):
return interpolate(obj, mapping)
value = interpolate(obj, mapping)
if value != obj and field_def is not None:
return _cast_interpolated(value, field_def)
return value
elif isinstance(obj, dict):
return dict(
(key, recursive_interpolate(val, mapping))
(key, recursive_interpolate(val, mapping, _get_sub_field_def(field_def, key)))
for (key, val) in obj.items()
)
elif isinstance(obj, list):
return [recursive_interpolate(val, mapping) for val in obj]
return [recursive_interpolate(val, mapping, _get_sub_field_def(field_def, "items")) for val in obj]
else:
return obj


def interpolate(string, mapping):
if '$' not in string:
return string
try:
return Template(string).substitute(mapping)
except ValueError:
Expand Down
28 changes: 19 additions & 9 deletions compose/config/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,41 +386,51 @@ def format_error_message(error):


def validate_against_fields_schema(config_file):
schema_filename = "fields_schema_v{0}.json".format(config_file.version)
schema = load_jsonschema("fields", config_file.version)
_validate_against_schema(
config_file.config,
schema_filename,
schema,
format_checker=["ports", "expose", "bool-value-in-mapping"],
filename=config_file.filename)


def validate_against_service_schema(config, service_name, version):
_validate_against_schema(
config,
"service_schema_v{0}.json".format(version),
load_jsonschema("service", version),
format_checker=["ports"],
path_prefix=[service_name])


def load_jsonschema(prefix, version):
schema_filename = "{0}_schema_v{1}.json".format(prefix, version)
config_source_dir = os.path.dirname(os.path.abspath(__file__))

if sys.platform == "win32":
config_source_dir = config_source_dir.replace('\\', '/')
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is necessary. On windows I would expect os.path to return the correct paths.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the if sys.platform == "win32": was here in the code I moved ...

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, I saw that, but I'm not sure why it's there either. This is mostly just a question, it may be necessary. If you can remove it and the appveyor tests pass (they run on windows) then I think it's safe to remove.


schema_file = os.path.join(config_source_dir, schema_filename)

with open(schema_file, "r") as schema_fh:
schema = json.load(schema_fh)
return schema


def _validate_against_schema(
config,
schema_filename,
schema,
format_checker=(),
path_prefix=None,
filename=None):
config_source_dir = os.path.dirname(os.path.abspath(__file__))

config_source_dir = os.path.dirname(os.path.abspath(__file__))
if sys.platform == "win32":
file_pre_fix = "///"
config_source_dir = config_source_dir.replace('\\', '/')
else:
file_pre_fix = "//"

resolver_full_path = "file:{}{}/".format(file_pre_fix, config_source_dir)
schema_file = os.path.join(config_source_dir, schema_filename)

with open(schema_file, "r") as schema_fh:
schema = json.load(schema_fh)

resolver = RefResolver(resolver_full_path, schema)
validation_output = Draft4Validator(
Expand Down
70 changes: 68 additions & 2 deletions tests/unit/config/interpolation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,32 @@
import pytest

from compose.config.interpolation import interpolate_environment_variables
from compose.config.validation import load_jsonschema
from compose.const import COMPOSEFILE_V2_0 as V2_0


@pytest.yield_fixture
def mock_env():
with mock.patch.dict(os.environ):
os.environ['USER'] = 'jenny'
os.environ['FOO'] = 'bar'
os.environ['RO'] = 'false'
os.environ['VOLUMES'] = "['/tmp/bar:/bar', '/tmp/foo:/foo']"
os.environ['VOLUMES_ITEM'] = '/tmp/foo:/foo'
os.environ['CPU_SHARES'] = '512'
os.environ['PORTS'] = "[ 8080 , 8089 ]"
os.environ['SINGLE_PORT'] = '8787'
yield


def test_interpolate_environment_variables_in_services(mock_env):
services = {
'servivea': {
'image': 'example:${USER}',
'entrypoint': '/bin/bash',
'volumes': ['$FOO:/target'],
'read_only': '${RO}',
'cpu_shares': '${CPU_SHARES}',
'logging': {
'driver': '${FOO}',
'options': {
Expand All @@ -33,7 +44,10 @@ def test_interpolate_environment_variables_in_services(mock_env):
expected = {
'servivea': {
'image': 'example:jenny',
'entrypoint': '/bin/bash',
'volumes': ['bar:/target'],
'read_only': False,
'cpu_shares': 512,
'logging': {
'driver': 'bar',
'options': {
Expand All @@ -42,7 +56,59 @@ def test_interpolate_environment_variables_in_services(mock_env):
}
}
}
assert interpolate_environment_variables(services, 'service') == expected
assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected


def test_interpolate_environment_variables_arrays_in_services(mock_env):
services = {
'servivea': {
'image': 'example:${USER}',
'volumes': '${VOLUMES}',
'cpu_shares': '${CPU_SHARES}',
}
}
expected = {
'servivea': {
'image': 'example:jenny',
'volumes': ['/tmp/bar:/bar', '/tmp/foo:/foo'],
'cpu_shares': 512,
}
}
assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected


def test_interpolate_environment_variables_array_element_in_services(mock_env):
services = {
'servivea': {
'image': 'example:${USER}',
'volumes': ['/tmp/bar:/bar', '${VOLUMES_ITEM}'],
'cpu_shares': '${CPU_SHARES}',
}
}
expected = {
'servivea': {
'image': 'example:jenny',
'volumes': ['/tmp/bar:/bar', '/tmp/foo:/foo'],
'cpu_shares': 512,
}
}
assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected


def test_interpolate_environment_variables_array_numbers_in_services(mock_env):
services = {
'servivea': {
'expose': '${PORTS}',
'ports': [8080, '${SINGLE_PORT}']
}
}
expected = {
'servivea': {
'expose': [8080, 8089],
'ports': [8080, 8787]
}
}
assert interpolate_environment_variables(services, 'service', load_jsonschema('service', V2_0)) == expected


def test_interpolate_environment_variables_in_volumes(mock_env):
Expand All @@ -66,4 +132,4 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
},
'other': {},
}
assert interpolate_environment_variables(volumes, 'volume') == expected
assert interpolate_environment_variables(volumes, 'volume', load_jsonschema('service', V2_0)) == expected