Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions compose/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,23 @@ def merge_service_dicts(base, override):
if 'build' in override and 'image' in d:
del d['image']

for k in ALLOWED_KEYS:
if k not in ['environment', 'volumes']:
if k in override:
d[k] = override[k]
list_keys = ['ports', 'expose', 'external_links']

for key in list_keys:
if key in base or key in override:
d[key] = base.get(key, []) + override.get(key, [])

list_or_string_keys = ['dns', 'dns_search']

for key in list_or_string_keys:
if key in base or key in override:
d[key] = to_list(base.get(key)) + to_list(override.get(key))

already_merged_keys = ['environment', 'volumes'] + list_keys + list_or_string_keys

for k in set(ALLOWED_KEYS) - set(already_merged_keys):
if k in override:
d[k] = override[k]

return d

Expand Down Expand Up @@ -354,6 +367,15 @@ def expand_path(working_dir, path):
return os.path.abspath(os.path.join(working_dir, path))


def to_list(value):
if value is None:
return []
elif isinstance(value, six.string_types):
return [value]
else:
return value


def get_service_name_from_net(net_config):
if not net_config:
return
Expand Down
71 changes: 64 additions & 7 deletions tests/unit/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,40 @@ def test_config_validation(self):
config.make_service_dict('foo', {'ports': ['8000']})


class MergeTest(unittest.TestCase):
def test_merge_volumes_empty(self):
class MergeVolumesTest(unittest.TestCase):
def test_empty(self):
service_dict = config.merge_service_dicts({}, {})
self.assertNotIn('volumes', service_dict)

def test_merge_volumes_no_override(self):
def test_no_override(self):
service_dict = config.merge_service_dicts(
{'volumes': ['/foo:/code', '/data']},
{},
)
self.assertEqual(set(service_dict['volumes']), set(['/foo:/code', '/data']))

def test_merge_volumes_no_base(self):
def test_no_base(self):
service_dict = config.merge_service_dicts(
{},
{'volumes': ['/bar:/code']},
)
self.assertEqual(set(service_dict['volumes']), set(['/bar:/code']))

def test_merge_volumes_override_explicit_path(self):
def test_override_explicit_path(self):
service_dict = config.merge_service_dicts(
{'volumes': ['/foo:/code', '/data']},
{'volumes': ['/bar:/code']},
)
self.assertEqual(set(service_dict['volumes']), set(['/bar:/code', '/data']))

def test_merge_volumes_add_explicit_path(self):
def test_add_explicit_path(self):
service_dict = config.merge_service_dicts(
{'volumes': ['/foo:/code', '/data']},
{'volumes': ['/bar:/code', '/quux:/data']},
)
self.assertEqual(set(service_dict['volumes']), set(['/bar:/code', '/quux:/data']))

def test_merge_volumes_remove_explicit_path(self):
def test_remove_explicit_path(self):
service_dict = config.merge_service_dicts(
{'volumes': ['/foo:/code', '/quux:/data']},
{'volumes': ['/bar:/code', '/data']},
Expand Down Expand Up @@ -114,6 +114,63 @@ def test_merge_build_or_image_override_with_other(self):
)


class MergeListsTest(unittest.TestCase):
def test_empty(self):
service_dict = config.merge_service_dicts({}, {})
self.assertNotIn('ports', service_dict)

def test_no_override(self):
service_dict = config.merge_service_dicts(
{'ports': ['10:8000', '9000']},
{},
)
self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000']))

def test_no_base(self):
service_dict = config.merge_service_dicts(
{},
{'ports': ['10:8000', '9000']},
)
self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000']))

def test_add_item(self):
service_dict = config.merge_service_dicts(
{'ports': ['10:8000', '9000']},
{'ports': ['20:8000']},
)
self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000', '20:8000']))


class MergeStringsOrListsTest(unittest.TestCase):
def test_no_override(self):
service_dict = config.merge_service_dicts(
{'dns': '8.8.8.8'},
{},
)
self.assertEqual(set(service_dict['dns']), set(['8.8.8.8']))

def test_no_base(self):
service_dict = config.merge_service_dicts(
{},
{'dns': '8.8.8.8'},
)
self.assertEqual(set(service_dict['dns']), set(['8.8.8.8']))

def test_add_string(self):
service_dict = config.merge_service_dicts(
{'dns': ['8.8.8.8']},
{'dns': '9.9.9.9'},
)
self.assertEqual(set(service_dict['dns']), set(['8.8.8.8', '9.9.9.9']))

def test_add_list(self):
service_dict = config.merge_service_dicts(
{'dns': '8.8.8.8'},
{'dns': ['9.9.9.9']},
)
self.assertEqual(set(service_dict['dns']), set(['8.8.8.8', '9.9.9.9']))


class EnvTest(unittest.TestCase):
def test_parse_environment_as_list(self):
environment = [
Expand Down