Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions compose/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,31 +413,35 @@ def merge_services(base, override):
return build_services(service_config)


def interpolate_config_section(filename, config, section, environment):
validate_config_section(filename, config, section)
return interpolate_environment_variables(config, section, environment)
def interpolate_config_section(config_file, config, section, environment):
validate_config_section(config_file.filename, config, section)
return interpolate_environment_variables(
config_file.version,
config,
section,
environment)


def process_config_file(config_file, environment, service_name=None):
services = interpolate_config_section(
config_file.filename,
config_file,
config_file.get_service_dicts(),
'service',
environment,)
environment)

if config_file.version in (V2_0, V2_1):
processed_config = dict(config_file.config)
processed_config['services'] = services
processed_config['volumes'] = interpolate_config_section(
config_file.filename,
config_file,
config_file.get_volumes(),
'volume',
environment,)
environment)
processed_config['networks'] = interpolate_config_section(
config_file.filename,
config_file,
config_file.get_networks(),
'network',
environment,)
environment)

if config_file.version == V1:
processed_config = services
Expand Down
74 changes: 57 additions & 17 deletions compose/config/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,35 @@
import six

from .errors import ConfigurationError
from compose.const import COMPOSEFILE_V1 as V1
from compose.const import COMPOSEFILE_V2_0 as V2_0


log = logging.getLogger(__name__)


def interpolate_environment_variables(config, section, environment):
class Interpolator(object):

def __init__(self, templater, mapping):
self.templater = templater
self.mapping = mapping

def interpolate(self, string):
try:
return self.templater(string).substitute(self.mapping)
except ValueError:
raise InvalidInterpolation(string)


def interpolate_environment_variables(version, config, section, environment):
if version in (V2_0, V1):
interpolator = Interpolator(Template, environment)
else:
interpolator = Interpolator(TemplateWithDefaults, environment)

def process_item(name, config_dict):
return dict(
(key, interpolate_value(name, key, val, section, environment))
(key, interpolate_value(name, key, val, section, interpolator))
for key, val in (config_dict or {}).items()
)

Expand All @@ -24,9 +45,9 @@ def process_item(name, config_dict):
)


def interpolate_value(name, config_key, value, section, mapping):
def interpolate_value(name, config_key, value, section, interpolator):
try:
return recursive_interpolate(value, mapping)
return recursive_interpolate(value, interpolator)
except InvalidInterpolation as e:
raise ConfigurationError(
'Invalid interpolation format for "{config_key}" option '
Expand All @@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping):
string=e.string))


def recursive_interpolate(obj, mapping):
def recursive_interpolate(obj, interpolator):
if isinstance(obj, six.string_types):
return interpolate(obj, mapping)
elif isinstance(obj, dict):
return interpolator.interpolate(obj)
if isinstance(obj, dict):
return dict(
(key, recursive_interpolate(val, mapping))
(key, recursive_interpolate(val, interpolator))
for (key, val) in obj.items()
)
elif isinstance(obj, list):
return [recursive_interpolate(val, mapping) for val in obj]
else:
return obj
if isinstance(obj, list):
return [recursive_interpolate(val, interpolator) for val in obj]
return obj


def interpolate(string, mapping):
try:
return Template(string).substitute(mapping)
except ValueError:
raise InvalidInterpolation(string)
class TemplateWithDefaults(Template):
idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?'

# Modified from python2.7/string.py
def substitute(self, mapping):
# Helper function for .sub()
def convert(mo):
# Check the most common path first.
named = mo.group('named') or mo.group('braced')
if named is not None:
if ':-' in named:
var, _, default = named.partition(':-')
return mapping.get(var) or default
if '-' in named:
var, _, default = named.partition('-')
return mapping.get(var, default)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These 3 lines are the only real substantial modifications.

I also cut out the arg/kwarg handling and just replaced it with a mapping arg, since that's how we use it.

val = mapping[named]
return '%s' % (val,)
if mo.group('escaped') is not None:
return self.delimiter
if mo.group('invalid') is not None:
self._invalid(mo)
raise ValueError('Unrecognized named group in pattern',
self.pattern)
return self.pattern.sub(convert, self.template)


class InvalidInterpolation(Exception):
Expand Down
74 changes: 60 additions & 14 deletions tests/unit/config/interpolation_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
from __future__ import absolute_import
from __future__ import unicode_literals

import os

import mock
import pytest

from compose.config.environment import Environment
from compose.config.interpolation import interpolate_environment_variables
from compose.config.interpolation import Interpolator
from compose.config.interpolation import InvalidInterpolation
from compose.config.interpolation import TemplateWithDefaults


@pytest.yield_fixture
@pytest.fixture
def mock_env():
with mock.patch.dict(os.environ):
os.environ['USER'] = 'jenny'
os.environ['FOO'] = 'bar'
yield
return Environment({'USER': 'jenny', 'FOO': 'bar'})


@pytest.fixture
def variable_mapping():
return Environment({'FOO': 'first', 'BAR': ''})


@pytest.fixture
def defaults_interpolator(variable_mapping):
return Interpolator(TemplateWithDefaults, variable_mapping).interpolate


def test_interpolate_environment_variables_in_services(mock_env):
Expand Down Expand Up @@ -43,9 +50,8 @@ def test_interpolate_environment_variables_in_services(mock_env):
}
}
}
assert interpolate_environment_variables(
services, 'service', Environment.from_env_file(None)
) == expected
value = interpolate_environment_variables("2.0", services, 'service', mock_env)
assert value == expected


def test_interpolate_environment_variables_in_volumes(mock_env):
Expand All @@ -69,6 +75,46 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
},
'other': {},
}
assert interpolate_environment_variables(
volumes, 'volume', Environment.from_env_file(None)
) == expected
value = interpolate_environment_variables("2.0", volumes, 'volume', mock_env)
assert value == expected


def test_escaped_interpolation(defaults_interpolator):
assert defaults_interpolator('$${foo}') == '${foo}'


def test_invalid_interpolation(defaults_interpolator):
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('$}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${ }')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${ foo}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${foo }')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${foo!}')


def test_interpolate_missing_no_default(defaults_interpolator):
assert defaults_interpolator("This ${missing} var") == "This var"
assert defaults_interpolator("This ${BAR} var") == "This var"


def test_interpolate_with_value(defaults_interpolator):
assert defaults_interpolator("This $FOO var") == "This first var"
assert defaults_interpolator("This ${FOO} var") == "This first var"


def test_interpolate_missing_with_default(defaults_interpolator):
assert defaults_interpolator("ok ${missing:-def}") == "ok def"
assert defaults_interpolator("ok ${missing-def}") == "ok def"


def test_interpolate_with_empty_and_default_value(defaults_interpolator):
assert defaults_interpolator("ok ${BAR:-def}") == "ok def"
assert defaults_interpolator("ok ${BAR-def}") == "ok "
36 changes: 0 additions & 36 deletions tests/unit/interpolation_test.py

This file was deleted.