Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions juju/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from .bundle import get_charm_series, is_local_charm
from .client import client
from .errors import JujuApplicationConfigError, JujuError
from .origin import Channel, Source
from .origin import Channel
from .placement import parse as parse_placement
from .relation import Relation
from .status import derive_status
from .url import URL
from .utils import block_until
from .version import DEFAULT_ARCHITECTURE

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -691,13 +692,15 @@ async def refresh(
if charm_url_origin_result.error is not None:
err = charm_url_origin_result.error
raise JujuError(f'{err.code} : {err.message}')
origin = charm_url_origin_result.charm_origin

current_origin = charm_url_origin_result.charm_origin
if path is not None or (switch is not None and is_local_charm(switch)):
await self.local_refresh(origin, force, force_series,
await self.local_refresh(current_origin, force, force_series,
force_units, path or switch, resources)
return

origin = _refresh_origin(current_origin, channel, revision)

# If switch is not None at this point, that means it's a switch to a store charm
charm_url = switch or charm_url_origin_result.url
parsed_url = URL.parse(charm_url)
Expand All @@ -706,20 +709,6 @@ async def refresh(
if parsed_url.schema is None:
raise JujuError(f'A ch: or cs: schema is required for application refresh, given : {str(parsed_url)}')

if revision is not None:
origin.revision = revision

# Make the source-specific changes to the origin/channel/url
# (and also get the resources necessary to deploy the (destination) charm -- for later)
origin.source = Source.CHARM_HUB.value
if channel:
ch = Channel.parse(channel).normalize()
origin.risk = ch.risk
origin.track = ch.track

charmhub = self.model.charmhub
charm_resources = await charmhub.list_resources(charm_name)

# Resolve the given charm URLs with an optionally specified preferred channel.
# Channel provided via CharmOrigin.
resolved_charm_with_channel_results = await charms_facade.ResolveCharms(
Expand Down Expand Up @@ -761,8 +750,7 @@ async def refresh(
else:
_arg_res_filenames[res] = filename_or_rev

# Already prepped the charm_resources
# Now get the existing resources from the ResourcesFacade
# Get the existing resources from the ResourcesFacade
request_data = [client.Entity(self.tag)]
resources_facade = client.ResourcesFacade.from_connection(self.connection)
response = await resources_facade.ListResources(entities=request_data)
Expand All @@ -771,6 +759,9 @@ async def refresh(
for resource in response.results[0].resources
}

charmhub = self.model.charmhub
charm_resources = await charmhub.list_resources(charm_name)

# Compute the difference btw resources needed and the existing resources
resources_to_update = []
for resource in charm_resources:
Expand Down Expand Up @@ -917,6 +908,20 @@ async def get_metrics(self):
return await self.model.get_metrics(self.tag)


def _refresh_origin(current_origin: client.CharmOrigin, channel=None, revision=None) -> client.CharmOrigin:
if channel is not None:
channel = Channel.parse(channel).normalize()

return client.CharmOrigin(
source=current_origin.source,
track=channel.track if channel else current_origin.track,
risk=channel.risk if channel else current_origin.risk,
revision=revision if revision is not None else current_origin.revision,
base=current_origin.base,
architecture=current_origin.get('architecture', DEFAULT_ARCHITECTURE),
)


class ExposedEndpoint:
"""ExposedEndpoint stores the list of CIDRs and space names which should be
allowed access to the port ranges that the application has opened for a
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,18 @@ async def test_local_refresh():
base=client.Base("20.04", "ubuntu"))


@base.bootstrapped
@pytest.mark.asyncio
async def test_refresh_revision():
async with base.CleanModel() as model:
app = await model.deploy('juju-qa-test', channel="latest/stable", revision=23)
# NOTE: juju-qa-test revision 26 has been released to this channel
await app.refresh(revision=25)

charm_url = URL.parse(app.data['charm-url'])
assert charm_url.revision == 25


@base.bootstrapped
@pytest.mark.asyncio
async def test_trusted():
Expand Down
55 changes: 54 additions & 1 deletion tests/unit/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import asyncio

from juju.model import Model
from juju.application import (Application, ExposedEndpoint)
from juju.application import Application, ExposedEndpoint, _refresh_origin
from juju.errors import JujuError
from juju.client import client
from juju.origin import Source


class TestExposeApplication(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -177,3 +179,54 @@ async def test_refresh_mutually_exclusive_kwargs(self, mock_conn):

with self.assertRaises(ValueError):
await app.refresh(switch="charm1", path="/path/to/charm2")

def test_refresh_origin(self):
current_origin = client.CharmOrigin(
source=str(Source.CHARM_HUB),
track="latest",
risk="stable",
revision=100,
base=client.Base("24.04", "ubuntu"),
architecture="amd64",
)

origin = _refresh_origin(current_origin, None, None)
self.assertEqual(origin, current_origin)

origin = _refresh_origin(current_origin, None, 101)
self.assertEqual(origin.revision, 101)
# Check source, base & arch do not change
self.assertEqual(origin.source, current_origin.source)
self.assertEqual(origin.base, current_origin.base)
self.assertEqual(origin.architecture, current_origin.architecture)

origin = _refresh_origin(current_origin, None, 0)
self.assertEqual(origin.revision, 0)
# Check source, base & arch do not change
self.assertEqual(origin.source, current_origin.source)
self.assertEqual(origin.base, current_origin.base)
self.assertEqual(origin.architecture, current_origin.architecture)

origin = _refresh_origin(current_origin, "12/edge", None)
self.assertEqual(origin.track, "12")
self.assertEqual(origin.risk, "edge")
# Check source, base & arch do not change
self.assertEqual(origin.source, current_origin.source)
self.assertEqual(origin.base, current_origin.base)
self.assertEqual(origin.architecture, current_origin.architecture)

def test_refresh_origin_drops_id_hash(self):
current_origin = client.CharmOrigin(
source=str(Source.CHARM_HUB),
track="latest",
risk="stable",
revision=100,
base=client.Base("24.04", "ubuntu"),
architecture="amd64",
id_="id",
hash_="hash",
)

origin = _refresh_origin(current_origin, None, None)
self.assertIsNone(origin.id_)
self.assertIsNone(origin.hash_)